sqlx
This commit is contained in:
@@ -1,196 +1,198 @@
|
||||
package mid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
//import (
|
||||
// "context"
|
||||
// "errors"
|
||||
// "net/http"
|
||||
// "sync"
|
||||
// "time"
|
||||
//
|
||||
// systemmodel "management/internal/erpserver/model/system"
|
||||
// v1 "management/internal/erpserver/service/v1"
|
||||
// "management/internal/pkg/know"
|
||||
// "management/internal/pkg/session"
|
||||
//
|
||||
// "github.com/drhin/logger"
|
||||
// "go.uber.org/zap"
|
||||
//)
|
||||
//
|
||||
//// AuditBuffer 审计日志缓冲器
|
||||
//type AuditBuffer struct {
|
||||
// auditLogService v1.AuditLogService
|
||||
// log *logger.Logger
|
||||
// buffer chan *systemmodel.AuditLog
|
||||
// stopCh chan struct{}
|
||||
// wg sync.WaitGroup
|
||||
// batchSize int
|
||||
// flushInterval time.Duration
|
||||
//}
|
||||
//
|
||||
//// NewAuditBuffer 创建审计日志缓冲器
|
||||
//func NewAuditBuffer(auditLogService v1.AuditLogService, log *logger.Logger) *AuditBuffer {
|
||||
// return &AuditBuffer{
|
||||
// auditLogService: auditLogService,
|
||||
// log: log,
|
||||
// buffer: make(chan *systemmodel.AuditLog, 10000), // 缓冲区大小
|
||||
// stopCh: make(chan struct{}),
|
||||
// batchSize: 50, // 批量大小
|
||||
// flushInterval: 3 * time.Second, // 刷新间隔
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//// Start 启动缓冲器
|
||||
//func (ab *AuditBuffer) Start() {
|
||||
// ab.wg.Add(1)
|
||||
// go ab.processBuffer()
|
||||
//}
|
||||
//
|
||||
//// Stop 停止缓冲器
|
||||
//func (ab *AuditBuffer) Stop() {
|
||||
// close(ab.stopCh)
|
||||
// ab.wg.Wait()
|
||||
// close(ab.buffer)
|
||||
//}
|
||||
//
|
||||
//// Add 添加审计日志到缓冲区
|
||||
//func (ab *AuditBuffer) Add(auditLog *systemmodel.AuditLog) {
|
||||
// select {
|
||||
// case ab.buffer <- auditLog:
|
||||
// // 成功添加到缓冲区
|
||||
// default:
|
||||
// // 缓冲区满,记录警告但不阻塞
|
||||
// ab.log.Warn("审计日志缓冲区已满,丢弃日志")
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//// processBuffer 处理缓冲区中的日志
|
||||
//func (ab *AuditBuffer) processBuffer() {
|
||||
// defer ab.wg.Done()
|
||||
//
|
||||
// ticker := time.NewTicker(ab.flushInterval)
|
||||
// defer ticker.Stop()
|
||||
//
|
||||
// batch := make([]*systemmodel.AuditLog, 0, ab.batchSize)
|
||||
//
|
||||
// flushBatch := func() {
|
||||
// if len(batch) == 0 {
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// // 批量插入
|
||||
// if err := ab.batchInsert(ctx, batch); err != nil {
|
||||
// ab.log.Error("批量插入审计日志失败", err, zap.Int("count", len(batch)))
|
||||
// } else {
|
||||
// ab.log.Debug("批量插入审计日志成功", zap.Int("count", len(batch)))
|
||||
// }
|
||||
//
|
||||
// // 清空批次
|
||||
// batch = batch[:0]
|
||||
// }
|
||||
//
|
||||
// for {
|
||||
// select {
|
||||
// case <-ab.stopCh:
|
||||
// // 停止信号,处理剩余的日志
|
||||
// for len(ab.buffer) > 0 {
|
||||
// select {
|
||||
// case auditLog := <-ab.buffer:
|
||||
// batch = append(batch, auditLog)
|
||||
// if len(batch) >= ab.batchSize {
|
||||
// flushBatch()
|
||||
// }
|
||||
// default:
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// flushBatch() // 处理最后一批
|
||||
// return
|
||||
//
|
||||
// case <-ticker.C:
|
||||
// // 定时刷新
|
||||
// flushBatch()
|
||||
//
|
||||
// case auditLog := <-ab.buffer:
|
||||
// // 收到新的审计日志
|
||||
// batch = append(batch, auditLog)
|
||||
// if len(batch) >= ab.batchSize {
|
||||
// flushBatch()
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//// batchInsert 批量插入数据库
|
||||
//func (ab *AuditBuffer) batchInsert(ctx context.Context, auditLogs []*systemmodel.AuditLog) error {
|
||||
// maxRetries := 3
|
||||
// for i := 0; i < maxRetries; i++ {
|
||||
// // 假设你的服务有批量创建方法,如果没有,需要添加
|
||||
// if err := ab.auditLogService.BatchCreate(ctx, auditLogs); err != nil {
|
||||
// if i == maxRetries-1 {
|
||||
// return err
|
||||
// }
|
||||
// ab.log.Error("批量插入失败,准备重试", err, zap.Int("retry", i+1))
|
||||
// time.Sleep(time.Duration(i+1) * time.Second)
|
||||
// continue
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
// return nil
|
||||
//}
|
||||
//
|
||||
//// 全局缓冲器实例
|
||||
//var globalAuditBuffer *AuditBuffer
|
||||
//
|
||||
//// InitAuditBuffer 初始化全局缓冲器
|
||||
//func InitAuditBuffer(auditLogService v1.AuditLogService, log *logger.Logger) {
|
||||
// globalAuditBuffer = NewAuditBuffer(auditLogService, log)
|
||||
// globalAuditBuffer.Start()
|
||||
//}
|
||||
//
|
||||
//// StopAuditBuffer 停止全局缓冲器
|
||||
//func StopAuditBuffer() {
|
||||
// if globalAuditBuffer != nil {
|
||||
// globalAuditBuffer.Stop()
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//// Audit 优化后的中间件
|
||||
//func Audit(sess session.Manager, log *logger.Logger) func(http.Handler) http.Handler {
|
||||
// return func(next http.Handler) http.Handler {
|
||||
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// start := time.Now()
|
||||
//
|
||||
// // 提前获取用户信息
|
||||
// user, err := sess.GetUser(r.Context(), know.StoreName)
|
||||
// if err != nil {
|
||||
// log.Error("获取用户会话失败", err)
|
||||
// next.ServeHTTP(w, r)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// // 处理请求
|
||||
// next.ServeHTTP(w, r)
|
||||
//
|
||||
// // 异步添加到缓冲区
|
||||
// go func() {
|
||||
// if user.ID == 0 {
|
||||
// log.Error("用户信息为空", errors.New("user is empty"))
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// auditLog := systemmodel.NewAuditLog(r, user.Email, user.OS, user.Browser, start, time.Now())
|
||||
//
|
||||
// // 添加到缓冲区,不会阻塞
|
||||
// if globalAuditBuffer != nil {
|
||||
// globalAuditBuffer.Add(auditLog)
|
||||
// }
|
||||
// }()
|
||||
// })
|
||||
// }
|
||||
//}
|
||||
|
||||
systemmodel "management/internal/erpserver/model/system"
|
||||
v1 "management/internal/erpserver/service/v1"
|
||||
"management/internal/pkg/know"
|
||||
"management/internal/pkg/session"
|
||||
|
||||
"github.com/drhin/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AuditBuffer 审计日志缓冲器
|
||||
type AuditBuffer struct {
|
||||
auditLogService v1.AuditLogService
|
||||
log *logger.Logger
|
||||
buffer chan *systemmodel.AuditLog
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
batchSize int
|
||||
flushInterval time.Duration
|
||||
}
|
||||
|
||||
// NewAuditBuffer 创建审计日志缓冲器
|
||||
func NewAuditBuffer(auditLogService v1.AuditLogService, log *logger.Logger) *AuditBuffer {
|
||||
return &AuditBuffer{
|
||||
auditLogService: auditLogService,
|
||||
log: log,
|
||||
buffer: make(chan *systemmodel.AuditLog, 10000), // 缓冲区大小
|
||||
stopCh: make(chan struct{}),
|
||||
batchSize: 50, // 批量大小
|
||||
flushInterval: 3 * time.Second, // 刷新间隔
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动缓冲器
|
||||
func (ab *AuditBuffer) Start() {
|
||||
ab.wg.Add(1)
|
||||
go ab.processBuffer()
|
||||
}
|
||||
|
||||
// Stop 停止缓冲器
|
||||
func (ab *AuditBuffer) Stop() {
|
||||
close(ab.stopCh)
|
||||
ab.wg.Wait()
|
||||
close(ab.buffer)
|
||||
}
|
||||
|
||||
// Add 添加审计日志到缓冲区
|
||||
func (ab *AuditBuffer) Add(auditLog *systemmodel.AuditLog) {
|
||||
select {
|
||||
case ab.buffer <- auditLog:
|
||||
// 成功添加到缓冲区
|
||||
default:
|
||||
// 缓冲区满,记录警告但不阻塞
|
||||
ab.log.Warn("审计日志缓冲区已满,丢弃日志")
|
||||
}
|
||||
}
|
||||
|
||||
// processBuffer 处理缓冲区中的日志
|
||||
func (ab *AuditBuffer) processBuffer() {
|
||||
defer ab.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(ab.flushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
batch := make([]*systemmodel.AuditLog, 0, ab.batchSize)
|
||||
|
||||
flushBatch := func() {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 批量插入
|
||||
if err := ab.batchInsert(ctx, batch); err != nil {
|
||||
ab.log.Error("批量插入审计日志失败", err, zap.Int("count", len(batch)))
|
||||
} else {
|
||||
ab.log.Debug("批量插入审计日志成功", zap.Int("count", len(batch)))
|
||||
}
|
||||
|
||||
// 清空批次
|
||||
batch = batch[:0]
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ab.stopCh:
|
||||
// 停止信号,处理剩余的日志
|
||||
for len(ab.buffer) > 0 {
|
||||
select {
|
||||
case auditLog := <-ab.buffer:
|
||||
batch = append(batch, auditLog)
|
||||
if len(batch) >= ab.batchSize {
|
||||
flushBatch()
|
||||
}
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
flushBatch() // 处理最后一批
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
// 定时刷新
|
||||
flushBatch()
|
||||
|
||||
case auditLog := <-ab.buffer:
|
||||
// 收到新的审计日志
|
||||
batch = append(batch, auditLog)
|
||||
if len(batch) >= ab.batchSize {
|
||||
flushBatch()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// batchInsert 批量插入数据库
|
||||
func (ab *AuditBuffer) batchInsert(ctx context.Context, auditLogs []*systemmodel.AuditLog) error {
|
||||
maxRetries := 3
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
// 假设你的服务有批量创建方法,如果没有,需要添加
|
||||
if err := ab.auditLogService.BatchCreate(ctx, auditLogs); err != nil {
|
||||
if i == maxRetries-1 {
|
||||
return err
|
||||
}
|
||||
ab.log.Error("批量插入失败,准备重试", err, zap.Int("retry", i+1))
|
||||
time.Sleep(time.Duration(i+1) * time.Second)
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 全局缓冲器实例
|
||||
var globalAuditBuffer *AuditBuffer
|
||||
|
||||
// InitAuditBuffer 初始化全局缓冲器
|
||||
func InitAuditBuffer(auditLogService v1.AuditLogService, log *logger.Logger) {
|
||||
globalAuditBuffer = NewAuditBuffer(auditLogService, log)
|
||||
globalAuditBuffer.Start()
|
||||
}
|
||||
|
||||
// StopAuditBuffer 停止全局缓冲器
|
||||
func StopAuditBuffer() {
|
||||
if globalAuditBuffer != nil {
|
||||
globalAuditBuffer.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// Audit 优化后的中间件
|
||||
func Audit(sess session.Manager, log *logger.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// 提前获取用户信息
|
||||
user, err := sess.GetUser(r.Context(), know.StoreName)
|
||||
if err != nil {
|
||||
log.Error("获取用户会话失败", err)
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理请求
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
// 异步添加到缓冲区
|
||||
go func() {
|
||||
if user.ID == 0 {
|
||||
log.Error("用户信息为空", errors.New("user is empty"))
|
||||
return
|
||||
}
|
||||
|
||||
auditLog := systemmodel.NewAuditLog(r, user.Email, user.OS, user.Browser, start, time.Now())
|
||||
|
||||
// 添加到缓冲区,不会阻塞
|
||||
if globalAuditBuffer != nil {
|
||||
globalAuditBuffer.Add(auditLog)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
}
|
||||
// ======================================================
|
||||
|
||||
// 如果你的AuditLogService没有BatchCreate方法,需要添加这个接口
|
||||
// 在你的service接口中添加:
|
||||
|
||||
57
internal/pkg/mid/audit_v3.go
Normal file
57
internal/pkg/mid/audit_v3.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package mid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"management/internal/erpserver/model/system"
|
||||
"management/internal/pkg/know"
|
||||
"management/internal/pkg/session"
|
||||
"management/internal/tasks"
|
||||
|
||||
"github.com/drhin/logger"
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Audit 改造后的中间件
|
||||
func Audit(sess session.Manager, log *logger.Logger, task tasks.TaskDistributor) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
start := time.Now()
|
||||
|
||||
user, err := sess.GetUser(ctx, know.StoreName)
|
||||
if err != nil {
|
||||
log.Error("获取用户会话失败", err)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
if user.ID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
payload := &tasks.PayloadConsumeAuditLog{
|
||||
AuditLog: system.NewAuditLog(r, user.Email, user.OS, user.Browser, start, time.Now()),
|
||||
}
|
||||
|
||||
opts := []asynq.Option{
|
||||
asynq.MaxRetry(10),
|
||||
asynq.ProcessIn(1 * time.Second),
|
||||
asynq.Queue(tasks.QueueCritical),
|
||||
}
|
||||
|
||||
c, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := task.DistributeTaskConsumeAuditLog(c, payload, opts...); err != nil {
|
||||
log.Error("distribute task failed", err,
|
||||
zap.String("type", "audit"),
|
||||
zap.Any("payload", payload),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,41 +7,39 @@ import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"management/internal/erpserver/model/dto"
|
||||
v1 "management/internal/erpserver/service/v1"
|
||||
"management/internal/pkg/know"
|
||||
"management/internal/pkg/session"
|
||||
|
||||
"github.com/json-iterator/go"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// 高性能JSON库(全局初始化)
|
||||
var json = jsoniter.ConfigFastest
|
||||
|
||||
// 使用jsoniter优化菜单结构体序列化
|
||||
func init() {
|
||||
jsoniter.RegisterTypeEncoderFunc("dto.OwnerMenuDto",
|
||||
func(ptr unsafe.Pointer, stream *jsoniter.Stream) {
|
||||
m := (*dto.OwnerMenuDto)(ptr)
|
||||
stream.WriteObjectStart()
|
||||
stream.WriteObjectField("id")
|
||||
stream.WriteUint(uint(m.ID))
|
||||
stream.WriteMore()
|
||||
stream.WriteObjectField("url")
|
||||
stream.WriteString(m.Url)
|
||||
stream.WriteMore()
|
||||
stream.WriteObjectField("parentId")
|
||||
stream.WriteUint(uint(m.ParentID))
|
||||
stream.WriteMore()
|
||||
stream.WriteObjectField("isList")
|
||||
stream.WriteBool(m.IsList)
|
||||
stream.WriteObjectEnd()
|
||||
}, nil)
|
||||
}
|
||||
//// 高性能JSON库(全局初始化)
|
||||
//var json = jsoniter.ConfigFastest
|
||||
//
|
||||
//// 使用jsoniter优化菜单结构体序列化
|
||||
//func init() {
|
||||
// jsoniter.RegisterTypeEncoderFunc("dto.OwnerMenuDto",
|
||||
// func(ptr unsafe.Pointer, stream *jsoniter.Stream) {
|
||||
// m := (*dto.OwnerMenuDto)(ptr)
|
||||
// stream.WriteObjectStart()
|
||||
// stream.WriteObjectField("id")
|
||||
// stream.WriteUint(uint(m.ID))
|
||||
// stream.WriteMore()
|
||||
// stream.WriteObjectField("url")
|
||||
// stream.WriteString(m.Url)
|
||||
// stream.WriteMore()
|
||||
// stream.WriteObjectField("parentId")
|
||||
// stream.WriteUint(uint(m.ParentID))
|
||||
// stream.WriteMore()
|
||||
// stream.WriteObjectField("isList")
|
||||
// stream.WriteBool(m.IsList)
|
||||
// stream.WriteObjectEnd()
|
||||
// }, nil)
|
||||
//}
|
||||
|
||||
var publicRoutes = map[string]bool{
|
||||
"/home.html": true,
|
||||
|
||||
95
internal/pkg/mid/authorize_v6.go
Normal file
95
internal/pkg/mid/authorize_v6.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package mid
|
||||
|
||||
//var publicRoutes = map[string]bool{
|
||||
// "/home.html": true,
|
||||
// "/dashboard": true,
|
||||
// "/system/menus": true,
|
||||
// "/upload/img": true,
|
||||
// "/upload/file": true,
|
||||
// "/upload/multi_files": true,
|
||||
// "/system/pear.json": true,
|
||||
// "/logout": true,
|
||||
//}
|
||||
//
|
||||
//var m sync.Map
|
||||
//
|
||||
//func Authorize(
|
||||
// sess session.Manager,
|
||||
// menuService v1.MenuService,
|
||||
//) func(http.Handler) http.Handler {
|
||||
// return func(next http.Handler) http.Handler {
|
||||
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// ctx := r.Context()
|
||||
// path := r.URL.Path
|
||||
//
|
||||
// // 登陆检查
|
||||
// user, err := sess.GetUser(ctx, know.StoreName)
|
||||
// if err != nil || user.ID == 0 {
|
||||
// http.Redirect(w, r, "/", http.StatusFound)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// // 公共路由放行
|
||||
// if publicRoutes[path] {
|
||||
// ctx = setUser(ctx, user)
|
||||
// next.ServeHTTP(w, r.WithContext(ctx))
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// n1 := time.Now()
|
||||
// // 权限检查
|
||||
// var menus map[string]*dto.OwnerMenuDto
|
||||
// cacheKey := fmt.Sprintf("user_menus:%d", user.RoleID)
|
||||
// if value, ok := m.Load(cacheKey); ok {
|
||||
// menus = value.(map[string]*dto.OwnerMenuDto)
|
||||
// log.Printf("map (from cache): %s", time.Since(n1).String())
|
||||
// } else {
|
||||
// menus, err = menuService.ListByRoleIDToMap(ctx, user.RoleID)
|
||||
// if err == nil {
|
||||
// m.Store(cacheKey, menus)
|
||||
// }
|
||||
// log.Printf("map (from DB, then cached): %s", time.Since(n1).String())
|
||||
// }
|
||||
//
|
||||
// if !hasPermission(menus, path) {
|
||||
// http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// cur := getCurrentMenus(menus, path)
|
||||
//
|
||||
// ctx = setUser(ctx, user)
|
||||
// ctx = setCurMenus(ctx, cur)
|
||||
//
|
||||
// next.ServeHTTP(w, r.WithContext(ctx))
|
||||
// })
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//func hasPermission(menus map[string]*dto.OwnerMenuDto, path string) bool {
|
||||
// _, ok := menus[path]
|
||||
// return ok
|
||||
//}
|
||||
//
|
||||
//func getCurrentMenus(data map[string]*dto.OwnerMenuDto, path string) []dto.OwnerMenuDto {
|
||||
// var res []dto.OwnerMenuDto
|
||||
//
|
||||
// menu, ok := data[path]
|
||||
// if !ok {
|
||||
// return res
|
||||
// }
|
||||
//
|
||||
// for _, item := range data {
|
||||
// if menu.IsList {
|
||||
// if item.ParentID == menu.ID || item.ID == menu.ID {
|
||||
// res = append(res, *item)
|
||||
// }
|
||||
// } else {
|
||||
// if item.ParentID == menu.ParentID {
|
||||
// res = append(res, *item)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// return res
|
||||
//}
|
||||
@@ -2,8 +2,10 @@ package mid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"management/internal/erpserver/model/dto"
|
||||
"management/internal/pkg/sqldb"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
)
|
||||
@@ -73,3 +75,19 @@ func GetHtmlCsrfToken(ctx context.Context) templ.Component {
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
type trkey struct{}
|
||||
|
||||
func setTran(ctx context.Context, tx sqldb.CommitRollbacker) context.Context {
|
||||
return context.WithValue(ctx, trkey{}, tx)
|
||||
}
|
||||
|
||||
// GetTran retrieves the value that can manage a transaction.
|
||||
func GetTran(ctx context.Context) (sqldb.CommitRollbacker, error) {
|
||||
v, ok := ctx.Value(trkey{}).(sqldb.CommitRollbacker)
|
||||
if !ok {
|
||||
return nil, errors.New("transaction not found in context")
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
57
internal/pkg/mid/transaction.go
Normal file
57
internal/pkg/mid/transaction.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package mid
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"management/internal/pkg/sqldb"
|
||||
|
||||
"github.com/drhin/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func BeginCommitRollback(log *logger.Logger, bgn sqldb.Beginner) func(http.Handler) http.Handler {
|
||||
m := func(next http.Handler) http.Handler {
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hasCommitted := false
|
||||
|
||||
log.Info("BEGIN TRANSACTION")
|
||||
tx, err := bgn.Begin()
|
||||
if err != nil {
|
||||
log.Error("BEGIN TRANSACTION", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if !hasCommitted {
|
||||
log.Info("ROLLBACK TRANSACTION")
|
||||
}
|
||||
|
||||
if err := tx.Rollback(); err != nil {
|
||||
if errors.Is(err, sql.ErrTxDone) {
|
||||
return
|
||||
}
|
||||
log.Info("ROLLBACK TRANSACTION", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = setTran(ctx, tx)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
|
||||
log.Info("COMMIT TRANSACTION")
|
||||
if err := tx.Commit(); err != nil {
|
||||
log.Error("COMMIT TRANSACTION", err)
|
||||
return
|
||||
}
|
||||
|
||||
hasCommitted = true
|
||||
})
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
214
internal/pkg/sqldb/sqldb.go
Normal file
214
internal/pkg/sqldb/sqldb.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"management/internal/pkg/config"
|
||||
|
||||
"github.com/drhin/logger"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
uniqueViolation = "23505"
|
||||
undefinedTable = "42P01"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDBNotFound = sql.ErrNoRows
|
||||
ErrDBDuplicatedEntry = errors.New("duplicated entry")
|
||||
ErrUndefinedTable = errors.New("undefined table")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
User string
|
||||
Password string
|
||||
Host string
|
||||
Port int
|
||||
Name string
|
||||
MaxIdleConns int
|
||||
MaxOpenConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
ConnMaxIdleTime time.Duration
|
||||
}
|
||||
|
||||
func NewDB(config *config.Config, log *logger.Logger) (*sqlx.DB, func(), error) {
|
||||
dsn := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable",
|
||||
config.DB.Username,
|
||||
config.DB.Password,
|
||||
config.DB.Host,
|
||||
config.DB.Port,
|
||||
config.DB.DBName,
|
||||
)
|
||||
|
||||
db, err := sqlx.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("sqlx open db: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 设置最大空闲连接数(默认 2)
|
||||
db.SetMaxIdleConns(config.DB.MaxIdleConns)
|
||||
|
||||
// 设置最大打开连接数(默认 0 无限制)
|
||||
db.SetMaxOpenConns(config.DB.MaxOpenConns)
|
||||
|
||||
// 设置连接最大存活时间
|
||||
db.SetConnMaxLifetime(config.DB.ConnMaxLifetime)
|
||||
|
||||
// 设置连接最大空闲时间
|
||||
db.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime)
|
||||
|
||||
cleanup := func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Error("sql db close error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return db, cleanup, nil
|
||||
}
|
||||
|
||||
func NamedExecContext(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
switch data.(type) {
|
||||
case struct{}:
|
||||
log.Error("database.NamedExecContext (data is struct)", err,
|
||||
zap.String("query", query),
|
||||
zap.Any("ERROR", err))
|
||||
default:
|
||||
log.Error("database.NamedExecContext", err,
|
||||
zap.String("query", query),
|
||||
zap.Any("ERROR", err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := sqlx.NamedExecContext(ctx, db, query, data); err != nil {
|
||||
var pgError *pgconn.PgError
|
||||
if errors.As(err, &pgError) {
|
||||
switch pgError.Code {
|
||||
case undefinedTable:
|
||||
return ErrUndefinedTable
|
||||
case uniqueViolation:
|
||||
return ErrDBDuplicatedEntry
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NamedQueryStruct(ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest any) (err error) {
|
||||
q := queryString(query, data)
|
||||
rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
|
||||
if err != nil {
|
||||
var pqErr *pgconn.PgError
|
||||
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
|
||||
return ErrUndefinedTable
|
||||
}
|
||||
log.Error("NamedQueryStruct NamedQueryContext error", err,
|
||||
zap.String("query", q),
|
||||
zap.Any("data", data),
|
||||
)
|
||||
return err
|
||||
}
|
||||
defer func(rows *sqlx.Rows) {
|
||||
err := rows.Close()
|
||||
if err != nil {
|
||||
log.Error("rows close error", err)
|
||||
}
|
||||
}(rows)
|
||||
|
||||
if !rows.Next() {
|
||||
return ErrDBNotFound
|
||||
}
|
||||
|
||||
if err := rows.StructScan(dest); err != nil {
|
||||
log.Error("NamedQueryStruct StructScan error", err,
|
||||
zap.String("query", q),
|
||||
zap.Any("data", data),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NamedQuerySlice[T any](ctx context.Context, log *logger.Logger, db sqlx.ExtContext, query string, data any, dest *[]T) (err error) {
|
||||
q := queryString(query, data)
|
||||
rows, err := sqlx.NamedQueryContext(ctx, db, q, data)
|
||||
if err != nil {
|
||||
var pqErr *pgconn.PgError
|
||||
if errors.As(err, &pqErr) && pqErr.Code == undefinedTable {
|
||||
return ErrUndefinedTable
|
||||
}
|
||||
log.Error("NamedQueryStruct NamedQueryContext error", err,
|
||||
zap.String("query", q),
|
||||
zap.Any("data", data),
|
||||
)
|
||||
return err
|
||||
}
|
||||
defer func(rows *sqlx.Rows) {
|
||||
err := rows.Close()
|
||||
if err != nil {
|
||||
log.Error("rows close error", err)
|
||||
}
|
||||
}(rows)
|
||||
|
||||
var slice []T
|
||||
for rows.Next() {
|
||||
v := new(T)
|
||||
if err := rows.StructScan(v); err != nil {
|
||||
log.Error("NamedQuerySlice StructScan error", err,
|
||||
zap.String("query", q),
|
||||
zap.Any("data", data),
|
||||
)
|
||||
return err
|
||||
}
|
||||
slice = append(slice, *v)
|
||||
}
|
||||
*dest = slice
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func queryString(query string, args any) string {
|
||||
query, params, err := sqlx.Named(query, args)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
for _, param := range params {
|
||||
var value string
|
||||
switch v := param.(type) {
|
||||
case string:
|
||||
value = fmt.Sprintf("'%s'", v)
|
||||
case []byte:
|
||||
value = fmt.Sprintf("'%s'", string(v))
|
||||
default:
|
||||
value = fmt.Sprintf("%v", v)
|
||||
}
|
||||
query = strings.Replace(query, "?", value, 1)
|
||||
}
|
||||
|
||||
query = strings.ReplaceAll(query, "\t", "")
|
||||
query = strings.ReplaceAll(query, "\n", " ")
|
||||
|
||||
return strings.Trim(query, " ")
|
||||
}
|
||||
49
internal/pkg/sqldb/tran.go
Normal file
49
internal/pkg/sqldb/tran.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// Beginner represents a value that can begin a transaction.
|
||||
type Beginner interface {
|
||||
Begin() (CommitRollbacker, error)
|
||||
}
|
||||
|
||||
// CommitRollbacker represents a value that can commit or rollback a transaction.
|
||||
type CommitRollbacker interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
// DBBeginner implements the Beginner interface,
|
||||
type DBBeginner struct {
|
||||
sqlxDB *sqlx.DB
|
||||
}
|
||||
|
||||
// NewBeginner constructs a value that implements the beginner interface.
|
||||
func NewBeginner(sqlxDB *sqlx.DB) *DBBeginner {
|
||||
return &DBBeginner{
|
||||
sqlxDB: sqlxDB,
|
||||
}
|
||||
}
|
||||
|
||||
// Begin implements the Beginner interface and returns a concrete value that
|
||||
// implements the CommitRollbacker interface.
|
||||
func (db *DBBeginner) Begin() (CommitRollbacker, error) {
|
||||
return db.sqlxDB.Beginx()
|
||||
}
|
||||
|
||||
// GetExtContext is a helper function that extracts the sqlx value
|
||||
// from the domain transactor interface for transactional use.
|
||||
func GetExtContext(tx CommitRollbacker) (sqlx.ExtContext, error) {
|
||||
ec, ok := tx.(sqlx.ExtContext)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Transactor(%T) not of a type *sql.Tx", tx)
|
||||
}
|
||||
|
||||
return ec, nil
|
||||
}
|
||||
Reference in New Issue
Block a user