package mid

import (
	"context"
	"errors"
	"net/http"
	"sync"
	"time"

	systemmodel "management/internal/erpserver/model/system"
	v1 "management/internal/erpserver/service/v1"
	"management/internal/pkg/know"
	"management/internal/pkg/session"

	"github.com/drhin/logger"
	"go.uber.org/zap"
)

// Default tuning parameters for AuditBuffer.
const (
	auditBufferCapacity = 10000            // pending entries held before Add starts dropping
	auditBatchSize      = 50               // entries handed to one BatchCreate call
	auditFlushInterval  = 3 * time.Second  // maximum time between two flushes
	auditFlushTimeout   = 10 * time.Second // deadline for one flush, retries included
	auditInsertRetries  = 3                // BatchCreate attempts per batch
)

// AuditBuffer collects audit log entries in memory and writes them to the
// audit log service in batches from a single background goroutine.
type AuditBuffer struct {
	auditLogService v1.AuditLogService
	log             *logger.Logger
	buffer          chan *systemmodel.AuditLog
	stopCh          chan struct{}
	stopOnce        sync.Once
	wg              sync.WaitGroup
	batchSize       int
	flushInterval   time.Duration
}

// NewAuditBuffer returns an AuditBuffer that writes through auditLogService
// and reports problems to log. The buffer is idle until Start is called.
func NewAuditBuffer(auditLogService v1.AuditLogService, log *logger.Logger) *AuditBuffer {
	return &AuditBuffer{
		auditLogService: auditLogService,
		log:             log,
		buffer:          make(chan *systemmodel.AuditLog, auditBufferCapacity),
		stopCh:          make(chan struct{}),
		batchSize:       auditBatchSize,
		flushInterval:   auditFlushInterval,
	}
}

// Start launches the background goroutine that drains the buffer.
// Each call starts one more processor, so call it once per buffer.
func (ab *AuditBuffer) Start() {
	ab.wg.Add(1)
	go ab.processBuffer()
}

// Stop signals the processor to flush whatever is queued and waits for it to
// finish. It is safe to call more than once.
//
// The entry channel is deliberately never closed: request goroutines may
// still be calling Add, and a send on a closed channel panics.
func (ab *AuditBuffer) Stop() {
	ab.stopOnce.Do(func() { close(ab.stopCh) })
	ab.wg.Wait()
}

// Add queues auditLog for a later batch insert without ever blocking.
// The entry is dropped, with a warning, when the buffer has been stopped or
// is full. A nil entry is ignored.
func (ab *AuditBuffer) Add(auditLog *systemmodel.AuditLog) {
	if auditLog == nil {
		return
	}

	// Once stopped nothing reads the channel any more, so queueing would
	// silently strand the entry.
	select {
	case <-ab.stopCh:
		ab.log.Warn("审计日志缓冲器已停止,丢弃日志")
		return
	default:
	}

	select {
	case ab.buffer <- auditLog:
	default:
		// Buffer full: warn instead of blocking the caller.
		ab.log.Warn("审计日志缓冲区已满,丢弃日志")
	}
}

// processBuffer is the background loop. It flushes a batch when the batch
// reaches batchSize or when the flush ticker fires, and on stop it drains the
// entries still queued before returning.
func (ab *AuditBuffer) processBuffer() {
	defer ab.wg.Done()

	ticker := time.NewTicker(ab.flushInterval)
	defer ticker.Stop()

	batch := make([]*systemmodel.AuditLog, 0, ab.batchSize)

	// flush writes the current batch and empties it whether or not the
	// insert succeeded; a failed batch is logged and discarded.
	flush := func() {
		if len(batch) == 0 {
			return
		}

		ctx, cancel := context.WithTimeout(context.Background(), auditFlushTimeout)
		defer cancel()

		if err := ab.batchInsert(ctx, batch); err != nil {
			ab.log.Error("批量插入审计日志失败", err, zap.Int("count", len(batch)))
		} else {
			ab.log.Debug("批量插入审计日志成功", zap.Int("count", len(batch)))
		}

		// Keep the backing array for the next batch.
		batch = batch[:0]
	}

	// enqueue appends one entry and flushes as soon as the batch is full.
	enqueue := func(entry *systemmodel.AuditLog) {
		batch = append(batch, entry)
		if len(batch) >= ab.batchSize {
			flush()
		}
	}

	for {
		select {
		case <-ab.stopCh:
			// Shutdown: take everything that is already queued, write the
			// final partial batch, then exit.
			for {
				select {
				case entry := <-ab.buffer:
					enqueue(entry)
				default:
					flush()
					return
				}
			}

		case <-ticker.C:
			flush()

		case entry := <-ab.buffer:
			enqueue(entry)
		}
	}
}

// batchInsert writes auditLogs through the service's BatchCreate method,
// retrying up to auditInsertRetries times with a linearly growing pause
// (1s, 2s, ...). It returns nil on the first success, otherwise the error of
// the last attempt. The pause ends early when ctx is done, so a flush never
// outlives its deadline just to sleep.
func (ab *AuditBuffer) batchInsert(ctx context.Context, auditLogs []*systemmodel.AuditLog) error {
	var err error
	for attempt := 1; attempt <= auditInsertRetries; attempt++ {
		// Assumes the service exposes BatchCreate; see the notes at the end
		// of this file if it has to be added.
		if err = ab.auditLogService.BatchCreate(ctx, auditLogs); err == nil {
			return nil
		}
		if attempt == auditInsertRetries {
			break
		}

		ab.log.Error("批量插入失败,准备重试", err, zap.Int("retry", attempt))

		backoff := time.NewTimer(time.Duration(attempt) * time.Second)
		select {
		case <-ctx.Done():
			backoff.Stop()
			return err
		case <-backoff.C:
		}
	}
	return err
}

// globalAuditBuffer is the process-wide buffer used by the Audit middleware.
// It is nil until InitAuditBuffer has been called.
var globalAuditBuffer *AuditBuffer

// InitAuditBuffer creates the global buffer and starts its processor.
// NOTE(review): the variable is assigned without synchronization, so this is
// presumably meant to run once during startup, before requests are served.
func InitAuditBuffer(auditLogService v1.AuditLogService, log *logger.Logger) {
	globalAuditBuffer = NewAuditBuffer(auditLogService, log)
	globalAuditBuffer.Start()
}

// StopAuditBuffer flushes and stops the global buffer. It does nothing when
// the buffer was never initialized.
func StopAuditBuffer() {
	if globalAuditBuffer != nil {
		globalAuditBuffer.Stop()
	}
}

// Audit returns middleware that records one audit log entry per request made
// with a valid user session. The entry is handed to the global AuditBuffer,
// so the request never waits on the database.
//
// When the session cannot be loaded the error is logged and the request is
// served without being audited.
func Audit(sess session.Manager, log *logger.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()

			// Resolve the user before the handler runs.
			user, err := sess.GetUser(r.Context(), know.StoreName)
			if err != nil {
				log.Error("获取用户会话失败", err)
				next.ServeHTTP(w, r)
				return
			}

			next.ServeHTTP(w, r)

			// Take the end time here, not in the goroutine below, so the
			// recorded duration excludes goroutine scheduling delay.
			end := time.Now()

			if user.ID == 0 {
				log.Error("用户信息为空", errors.New("user is empty"))
				return
			}

			buf := globalAuditBuffer
			if buf == nil {
				return
			}

			// Build the entry off the request path.
			// NOTE(review): NewAuditLog reads r after the handler has
			// returned; confirm it never touches r.Body.
			go func() {
				buf.Add(systemmodel.NewAuditLog(r, user.Email, user.OS, user.Browser, start, end))
			}()
		})
	}
}

// If AuditLogService does not have a BatchCreate method yet, it has to be
// added. In the service interface:
/*
type AuditLogService interface {
	Create(ctx context.Context, auditLog *systemmodel.AuditLog) error
	BatchCreate(ctx context.Context, auditLogs []*systemmodel.AuditLog) error
	// ... other methods
}
*/

// And a matching implementation (PostgreSQL batch-insert example):
/*
func (s *auditLogService) BatchCreate(ctx context.Context, auditLogs []*systemmodel.AuditLog) error {
	if len(auditLogs) == 0 {
		return nil
	}

	// Build the batch INSERT statement.
	query := `INSERT INTO audit_logs (user_id, email, ip, os, browser, method, url, start_time, end_time, duration) VALUES `
	values := make([]interface{}, 0, len(auditLogs)*10)

	for i, log := range auditLogs {
		if i > 0 {
			query += ", "
		}
		query += "($" + strconv.Itoa(i*10+1) + ", $" + strconv.Itoa(i*10+2) + ", $" + strconv.Itoa(i*10+3) + ", $" + strconv.Itoa(i*10+4) + ", $" + strconv.Itoa(i*10+5) + ", $" + strconv.Itoa(i*10+6) + ", $" + strconv.Itoa(i*10+7) + ", $" + strconv.Itoa(i*10+8) + ", $" + strconv.Itoa(i*10+9) + ", $" + strconv.Itoa(i*10+10) + ")"
		values = append(values, log.UserID, log.Email, log.Ip, log.Os, log.Browser, log.Method, log.Url, log.StartTime, log.EndTime, log.Duration)
	}

	_, err := s.db.ExecContext(ctx, query, values...)
	return err
}
*/