gorm wire

This commit is contained in:
2025-05-07 14:12:53 +08:00
parent 461531c308
commit 68606c76f9
111 changed files with 1726 additions and 5809 deletions

View File

@@ -2,58 +2,136 @@ package repository
import (
"context"
"sync"
"fmt"
"time"
"management/internal/pkg/config"
"github.com/drhin/logger"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
var (
once sync.Once
// 全局变量,方便其它包直接调用已初始化好的 datastore 实例.
engine *datastore
)
type txCtxKey struct{}
type Store interface {
DB(ctx context.Context) *gorm.DB
TX(ctx context.Context, fn func(ctx context.Context) error) error
type Repository struct {
db *gorm.DB
logger *logger.Logger
}
// transactionKey 用于在 context.Context 中存储事务上下文的键.
type transactionKey struct{}
// datastore 是 Storer 的具体实现.
type datastore struct {
core *gorm.DB
func NewRepository(db *gorm.DB, logger *logger.Logger) *Repository {
return &Repository{
db: db,
logger: logger,
}
}
// 确保 datastore 实现了 Storer 接口.
var _ Store = (*datastore)(nil)
type Transaction interface {
Transaction(ctx context.Context, fn func(ctx context.Context) error) error
}
// NewStore 创建一个 Storer 类型的实例.
func NewStore(db *gorm.DB) *datastore {
// 确保 engine 只被初始化一次
once.Do(func() {
engine = &datastore{db}
func NewTransaction(r *Repository) Transaction {
return r
}
func (r *Repository) DB(ctx context.Context) *gorm.DB {
v := ctx.Value(txCtxKey{})
if v != nil {
if tx, ok := v.(*gorm.DB); ok {
return tx
}
}
return r.db.WithContext(ctx)
}
func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
ctx = context.WithValue(ctx, txCtxKey{}, tx)
return fn(ctx)
})
return engine
}
func (store *datastore) DB(ctx context.Context) *gorm.DB {
db := store.core
// 从上下文中提取事务实例
if tx, ok := ctx.Value(transactionKey{}).(*gorm.DB); ok {
db = tx
func NewDB(log *logger.Logger, config *config.Config) (*gorm.DB, func(), error) {
dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable TimeZone=Asia/Shanghai",
config.DB.Host,
config.DB.Username,
config.DB.Password,
config.DB.DBName,
config.DB.Port,
)
db, err := gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
}), &gorm.Config{})
if err != nil {
return nil, nil, err
}
return db
db = db.Debug()
// Connection Pool config
sqlDB, err := db.DB()
if err != nil {
return nil, nil, err
}
sqlDB.SetMaxIdleConns(10)
sqlDB.SetMaxOpenConns(100)
sqlDB.SetConnMaxLifetime(time.Hour)
cleanup := func() {
if err := sqlDB.Close(); err != nil {
log.Error("sql db close error", err)
}
}
return db, cleanup, nil
}
func (store *datastore) TX(ctx context.Context, fn func(ctx context.Context) error) error {
return store.core.WithContext(ctx).Transaction(
func(tx *gorm.DB) error {
ctx = context.WithValue(ctx, transactionKey{}, tx)
return fn(ctx)
},
)
}
//var (
// once sync.Once
// // 全局变量,方便其它包直接调用已初始化好的 datastore 实例.
// engine *datastore
//)
//
//type Store interface {
// DB(ctx context.Context) *gorm.DB
// TX(ctx context.Context, fn func(ctx context.Context) error) error
//}
//
//// transactionKey 用于在 context.Context 中存储事务上下文的键.
//type transactionKey struct{}
//
//// datastore 是 Store 的具体实现.
//type datastore struct {
// core *gorm.DB
//}
//
//// NewStore 创建一个 Store 类型的实例.
//func NewStore(db *gorm.DB) Store {
// // 确保 engine 只被初始化一次
// once.Do(func() {
// engine = &datastore{db}
// })
//
// return engine
//}
//
//func (store *datastore) DB(ctx context.Context) *gorm.DB {
// db := store.core
// // 从上下文中提取事务实例
// if tx, ok := ctx.Value(transactionKey{}).(*gorm.DB); ok {
// db = tx
// }
//
// return db
//}
//
//func (store *datastore) TX(ctx context.Context, fn func(ctx context.Context) error) error {
// return store.core.WithContext(ctx).Transaction(
// func(tx *gorm.DB) error {
// ctx = context.WithValue(ctx, transactionKey{}, tx)
// return fn(ctx)
// },
// )
//}

View File

@@ -9,23 +9,21 @@ import (
)
type auditLogRepository struct {
store repository.Store
repo *repository.Repository
}
var _ system.AuditLogRepository = (*auditLogRepository)(nil)
func NewAuditLogRepository(store repository.Store) *auditLogRepository {
func NewAuditLogRepository(repo *repository.Repository) system.AuditLogRepository {
return &auditLogRepository{
store: store,
repo: repo,
}
}
func (s *auditLogRepository) Create(ctx context.Context, obj *system.AuditLog) error {
return s.store.DB(ctx).Create(obj).Error
return s.repo.DB(ctx).Create(obj).Error
}
func (s *auditLogRepository) List(ctx context.Context, q dto.SearchDto) ([]*system.AuditLog, int64, error) {
query := s.store.DB(ctx).
query := s.repo.DB(ctx).
Model(&system.AuditLog{}).
Where("created_at BETWEEN ? AND ?", q.SearchTimeBegin, q.SearchTimeEnd)
if q.SearchEmail != "" {

View File

@@ -9,28 +9,26 @@ import (
)
type configRepository struct {
store repository.Store
repo *repository.Repository
}
var _ system.ConfigRepository = (*configRepository)(nil)
func NewConfigRepository(store repository.Store) *configRepository {
func NewConfigRepository(repo *repository.Repository) system.ConfigRepository {
return &configRepository{
store: store,
repo: repo,
}
}
func (r *configRepository) Create(ctx context.Context, obj *system.Config) error {
return r.store.DB(ctx).Create(obj).Error
return r.repo.DB(ctx).Create(obj).Error
}
func (r *configRepository) Update(ctx context.Context, obj *system.Config) error {
return r.store.DB(ctx).Save(obj).Error
return r.repo.DB(ctx).Save(obj).Error
}
func (r *configRepository) Get(ctx context.Context, id int32) (*system.Config, error) {
var obj system.Config
err := r.store.DB(ctx).First(&obj, id).Error
err := r.repo.DB(ctx).First(&obj, id).Error
if err != nil {
return nil, err
}
@@ -39,7 +37,7 @@ func (r *configRepository) Get(ctx context.Context, id int32) (*system.Config, e
func (r *configRepository) GetByKey(ctx context.Context, key string) (*system.Config, error) {
var obj system.Config
err := r.store.DB(ctx).Where("key = ?", key).First(&obj).Error
err := r.repo.DB(ctx).Where("key = ?", key).First(&obj).Error
if err != nil {
return nil, err
}
@@ -47,5 +45,25 @@ func (r *configRepository) GetByKey(ctx context.Context, key string) (*system.Co
}
func (r *configRepository) List(ctx context.Context, q dto.SearchDto) ([]*system.Config, int64, error) {
return nil, 0, nil
query := r.repo.DB(ctx).
Model(&system.Config{}).
Where("created_at BETWEEN ? AND ?", q.SearchTimeBegin, q.SearchTimeEnd)
var count int64
err := query.Count(&count).Error
if err != nil {
return nil, 0, err
}
var configs []*system.Config
err = query.
Order("id DESC").
Offset((q.Page - 1) * q.Rows).
Limit(q.Rows).
Find(&configs).
Error
if err != nil {
return nil, 0, err
}
return configs, count, nil
}

View File

@@ -9,28 +9,26 @@ import (
)
type departmentRepository struct {
store repository.Store
repo *repository.Repository
}
var _ system.DepartmentRepository = (*departmentRepository)(nil)
func NewDepartmentRepository(store repository.Store) *departmentRepository {
func NewDepartmentRepository(repo *repository.Repository) system.DepartmentRepository {
return &departmentRepository{
store: store,
repo: repo,
}
}
func (r *departmentRepository) Create(ctx context.Context, obj *system.Department) error {
return r.store.DB(ctx).Create(obj).Error
return r.repo.DB(ctx).Create(obj).Error
}
func (r *departmentRepository) Update(ctx context.Context, obj *system.Department) error {
return r.store.DB(ctx).Save(obj).Error
return r.repo.DB(ctx).Save(obj).Error
}
func (r *departmentRepository) Get(ctx context.Context, id int32) (*system.Department, error) {
var obj system.Department
err := r.store.DB(ctx).First(&obj, id).Error
err := r.repo.DB(ctx).First(&obj, id).Error
if err != nil {
return nil, err
}
@@ -39,7 +37,7 @@ func (r *departmentRepository) Get(ctx context.Context, id int32) (*system.Depar
func (r *departmentRepository) All(ctx context.Context) ([]*system.Department, error) {
var departs []*system.Department
err := r.store.DB(ctx).Find(&departs).Error
err := r.repo.DB(ctx).Find(&departs).Error
if err != nil {
return nil, err
}
@@ -47,7 +45,7 @@ func (r *departmentRepository) All(ctx context.Context) ([]*system.Department, e
}
func (r *departmentRepository) List(ctx context.Context, q dto.SearchDto) ([]*system.Department, int64, error) {
query := r.store.DB(ctx).
query := r.repo.DB(ctx).
Model(&system.Department{}).
Where("created_at BETWEEN ? AND ?", q.SearchTimeBegin, q.SearchTimeEnd)
if q.SearchID != 0 {
@@ -96,5 +94,5 @@ SET parent_path = (SELECT ',' || string_agg(cast(t.parent_id AS VARCHAR), ',') |
FROM temp
ORDER BY id) AS t)
WHERE tm.status = 0;`
return r.store.DB(ctx).Exec(query).Error
return r.repo.DB(ctx).Exec(query).Error
}

View File

@@ -9,24 +9,22 @@ import (
)
type loginLogRepository struct {
store repository.Store
repo *repository.Repository
}
var _ system.LoginLogRepository = (*loginLogRepository)(nil)
func NewLoginLogRepository(store repository.Store) *loginLogRepository {
func NewLoginLogRepository(repo *repository.Repository) system.LoginLogRepository {
return &loginLogRepository{
store: store,
repo: repo,
}
}
func (s *loginLogRepository) Create(ctx context.Context, obj *system.LoginLog) error {
return s.store.DB(ctx).Create(obj).Error
return s.repo.DB(ctx).Create(obj).Error
}
func (s *loginLogRepository) GetLatest(ctx context.Context, email string) (*system.LoginLog, error) {
var log system.LoginLog
err := s.store.DB(ctx).
err := s.repo.DB(ctx).
Where("email = ?", email).
Order("id DESC").
First(&log).
@@ -38,7 +36,7 @@ func (s *loginLogRepository) GetLatest(ctx context.Context, email string) (*syst
}
func (s *loginLogRepository) List(ctx context.Context, q dto.SearchDto) ([]*system.LoginLog, int64, error) {
query := s.store.DB(ctx).
query := s.repo.DB(ctx).
Model(&system.LoginLog{}).
Where("created_at BETWEEN ? AND ?", q.SearchTimeBegin, q.SearchTimeEnd)
if q.SearchEmail != "" {
@@ -66,7 +64,7 @@ func (s *loginLogRepository) List(ctx context.Context, q dto.SearchDto) ([]*syst
func (s *loginLogRepository) Count(ctx context.Context, email string) (int64, error) {
var count int64
err := s.store.DB(ctx).
err := s.repo.DB(ctx).
Model(&system.LoginLog{}).
Where("email = ?", email).
Count(&count).

View File

@@ -8,28 +8,26 @@ import (
)
type menuRepository struct {
store repository.Store
repo *repository.Repository
}
var _ system.MenuRepository = (*menuRepository)(nil)
func NewMenuRepository(store repository.Store) *menuRepository {
func NewMenuRepository(repo *repository.Repository) system.MenuRepository {
return &menuRepository{
store: store,
repo: repo,
}
}
func (r *menuRepository) Create(ctx context.Context, obj *system.Menu) error {
return r.store.DB(ctx).Create(obj).Error
return r.repo.DB(ctx).Create(obj).Error
}
func (r *menuRepository) Update(ctx context.Context, obj *system.Menu) error {
return r.store.DB(ctx).Save(obj).Error
return r.repo.DB(ctx).Save(obj).Error
}
func (r *menuRepository) Get(ctx context.Context, id int32) (*system.Menu, error) {
var menu system.Menu
err := r.store.DB(ctx).Where("id = ?", id).First(&menu).Error
err := r.repo.DB(ctx).Where("id = ?", id).First(&menu).Error
if err != nil {
return nil, err
}
@@ -38,7 +36,7 @@ func (r *menuRepository) Get(ctx context.Context, id int32) (*system.Menu, error
func (r *menuRepository) GetByUrl(ctx context.Context, url string) (*system.Menu, error) {
var menu system.Menu
err := r.store.DB(ctx).Where("url = ?", url).First(&menu).Error
err := r.repo.DB(ctx).Where("url = ?", url).First(&menu).Error
if err != nil {
return nil, err
}
@@ -47,7 +45,7 @@ func (r *menuRepository) GetByUrl(ctx context.Context, url string) (*system.Menu
func (r *menuRepository) All(ctx context.Context) ([]*system.Menu, error) {
var menus []*system.Menu
err := r.store.DB(ctx).Find(&menus).Error
err := r.repo.DB(ctx).Find(&menus).Error
if err != nil {
return nil, err
}
@@ -69,5 +67,5 @@ func (r *menuRepository) RebuildParentPath(ctx context.Context) error {
FROM temp
ORDER BY id) AS t)
WHERE tm.status = 0;`
return r.store.DB(ctx).Exec(query).Error
return r.repo.DB(ctx).Exec(query).Error
}

View File

@@ -9,28 +9,26 @@ import (
)
type roleRepository struct {
store repository.Store
repo *repository.Repository
}
var _ system.RoleRepository = (*roleRepository)(nil)
func NewRoleRepository(store repository.Store) *roleRepository {
func NewRoleRepository(repo *repository.Repository) system.RoleRepository {
return &roleRepository{
store: store,
repo: repo,
}
}
func (r *roleRepository) Create(ctx context.Context, obj *system.Role) error {
return r.store.DB(ctx).Create(obj).Error
return r.repo.DB(ctx).Create(obj).Error
}
func (r *roleRepository) Update(ctx context.Context, obj *system.Role) error {
return r.store.DB(ctx).Save(obj).Error
return r.repo.DB(ctx).Save(obj).Error
}
func (r *roleRepository) Get(ctx context.Context, id int32) (*system.Role, error) {
var role system.Role
err := r.store.DB(ctx).Where("id = ?", id).First(&role).Error
err := r.repo.DB(ctx).Where("id = ?", id).First(&role).Error
if err != nil {
return nil, err
}
@@ -39,7 +37,7 @@ func (r *roleRepository) Get(ctx context.Context, id int32) (*system.Role, error
func (r *roleRepository) All(ctx context.Context) ([]*system.Role, error) {
var roles []*system.Role
err := r.store.DB(ctx).Find(&roles).Error
err := r.repo.DB(ctx).Find(&roles).Error
if err != nil {
return nil, err
}
@@ -47,7 +45,7 @@ func (r *roleRepository) All(ctx context.Context) ([]*system.Role, error) {
}
func (r *roleRepository) List(ctx context.Context, q dto.SearchDto) ([]*system.Role, int64, error) {
query := r.store.DB(ctx).
query := r.repo.DB(ctx).
Model(&system.Role{}).
Where("created_at BETWEEN ? AND ?", q.SearchTimeBegin, q.SearchTimeEnd)
if q.SearchID != 0 {
@@ -96,5 +94,5 @@ func (r *roleRepository) RebuildParentPath(ctx context.Context) error {
FROM temp
ORDER BY id) AS t)
WHERE tm.status = 0;`
return r.store.DB(ctx).Exec(query).Error
return r.repo.DB(ctx).Exec(query).Error
}

View File

@@ -8,28 +8,26 @@ import (
)
type roleMenuRepository struct {
store repository.Store
repo *repository.Repository
}
var _ system.RoleMenuRepository = (*roleMenuRepository)(nil)
func NewRoleMenuRepository(store repository.Store) *roleMenuRepository {
func NewRoleMenuRepository(repo *repository.Repository) system.RoleMenuRepository {
return &roleMenuRepository{
store: store,
repo: repo,
}
}
func (r *roleMenuRepository) Create(ctx context.Context, obj []*system.RoleMenu) error {
return r.store.DB(ctx).Create(obj).Error
return r.repo.DB(ctx).Create(obj).Error
}
func (r *roleMenuRepository) DeleteByRoleID(ctx context.Context, roleID int32) error {
return r.store.DB(ctx).Where("role_id = ?", roleID).Delete(&system.RoleMenu{}).Error
return r.repo.DB(ctx).Where("role_id = ?", roleID).Delete(&system.RoleMenu{}).Error
}
func (r *roleMenuRepository) ListByRoleID(ctx context.Context, roleID int32) ([]*system.RoleMenu, error) {
var roleMenus []*system.RoleMenu
err := r.store.DB(ctx).Where("role_id = ?", roleID).Find(&roleMenus).Error
err := r.repo.DB(ctx).Where("role_id = ?", roleID).Find(&roleMenus).Error
if err != nil {
return nil, err
}

View File

@@ -9,28 +9,26 @@ import (
)
type userRepository struct {
store repository.Store
repo *repository.Repository
}
var _ system.UserRepository = (*userRepository)(nil)
func NewUserRepository(store repository.Store) *userRepository {
func NewUserRepository(repo *repository.Repository) system.UserRepository {
return &userRepository{
store: store,
repo: repo,
}
}
func (s *userRepository) Create(ctx context.Context, obj *system.User) error {
return s.store.DB(ctx).Create(obj).Error
return s.repo.DB(ctx).Create(obj).Error
}
func (s *userRepository) Update(ctx context.Context, obj *system.User) error {
return s.store.DB(ctx).Save(obj).Error
return s.repo.DB(ctx).Save(obj).Error
}
func (s *userRepository) Get(ctx context.Context, id int32) (*system.User, error) {
var user system.User
err := s.store.DB(ctx).Where("id = ?", id).First(&user).Error
err := s.repo.DB(ctx).Where("id = ?", id).First(&user).Error
if err != nil {
return nil, err
}
@@ -39,7 +37,7 @@ func (s *userRepository) Get(ctx context.Context, id int32) (*system.User, error
func (s *userRepository) GetByEmail(ctx context.Context, email string) (*system.User, error) {
var user system.User
err := s.store.DB(ctx).Where("email = ?", email).First(&user).Error
err := s.repo.DB(ctx).Where("email = ?", email).First(&user).Error
if err != nil {
return nil, err
}
@@ -48,7 +46,7 @@ func (s *userRepository) GetByEmail(ctx context.Context, email string) (*system.
func (s *userRepository) All(ctx context.Context) ([]*system.User, error) {
var users []*system.User
err := s.store.DB(ctx).Find(&users).Error
err := s.repo.DB(ctx).Find(&users).Error
if err != nil {
return nil, err
}
@@ -56,7 +54,7 @@ func (s *userRepository) All(ctx context.Context) ([]*system.User, error) {
}
func (s *userRepository) List(ctx context.Context, q dto.SearchDto) ([]*system.User, int64, error) {
query := s.store.DB(ctx).
query := s.repo.DB(ctx).
Model(&system.User{}).
Preload("Role").
Preload("Department").