This commit is contained in:
2025-06-13 17:23:16 +08:00
parent 3150ba80bc
commit 1b72f51e4a
55 changed files with 3894 additions and 310 deletions

19
internal/pkg/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,19 @@
package cache
import (
"context"
"time"
"github.com/redis/go-redis/v9"
)
type Cache interface {
Encode(a any) ([]byte, error)
Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error
Del(ctx context.Context, keys ...string) error
Get(ctx context.Context, key string) (string, error)
GetBytes(ctx context.Context, key string) ([]byte, error)
Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd
Keys(ctx context.Context, pattern string) ([]string, error)
ListKeys(ctx context.Context, pattern string, pageID int, pageSize int) ([]string, int, error)
}

View File

@@ -1,4 +1,4 @@
package redis
package cache
import (
"bytes"
@@ -16,28 +16,17 @@ import (
var ErrRedisKeyNotFound = errors.New("redis key not found")
type Cache interface {
Encode(a any) ([]byte, error)
Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error
Del(ctx context.Context, keys ...string) error
Get(ctx context.Context, key string) (string, error)
GetBytes(ctx context.Context, key string) ([]byte, error)
Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd
Keys(ctx context.Context, pattern string) ([]string, error)
ListKeys(ctx context.Context, pattern string, pageID int, pageSize int) ([]string, int, error)
}
type redisCache struct {
client *redis.Client
}
func New(conf *config.Config, log *logger.Logger) (Cache, func(), error) {
func ConnectRedis(conf *config.Config, log *logger.Logger) (*redis.Client, func(), error) {
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", conf.Redis.Host, conf.Redis.Port),
Password: conf.Redis.Password,
DB: conf.Redis.DB,
})
_, err := rdb.Ping(context.Background()).Result()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
_, err := rdb.Ping(ctx).Result()
if err != nil {
return nil, nil, err
}
@@ -48,9 +37,17 @@ func New(conf *config.Config, log *logger.Logger) (Cache, func(), error) {
}
}
return rdb, cleanup, nil
}
type redisCache struct {
client *redis.Client
}
func NewRedisCache(client *redis.Client) Cache {
return &redisCache{
client: rdb,
}, cleanup, nil
client: client,
}
}
func (r *redisCache) Encode(a any) ([]byte, error) {

View File

@@ -0,0 +1,58 @@
package mid
//import (
// "context"
// "errors"
// "net/http"
// "time"
//
// systemmodel "management/internal/erpserver/model/system"
// v1 "management/internal/erpserver/service/v1"
// "management/internal/pkg/know"
// "management/internal/pkg/session"
//
// "github.com/drhin/logger"
// "go.uber.org/zap"
//)
//
//func Audit(sess session.Manager, auditLogService v1.AuditLogService, log *logger.Logger) func(http.Handler) http.Handler {
// return func(next http.Handler) http.Handler {
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// start := time.Now()
//
// // 提前获取用户信息(同步操作)
// user, err := sess.GetUser(r.Context(), know.StoreName)
// if err != nil {
// log.Error("获取用户会话失败", err)
// next.ServeHTTP(w, r) // 继续处理请求
// return
// }
//
// defer func() {
// go func() {
// if user.ID == 0 {
// log.Error("用户信息为空", errors.New("scs get user is empty"))
// return
// }
//
// ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
// defer cancel()
//
// al := systemmodel.NewAuditLog(r, user.Email, user.OS, user.Browser, start, time.Now())
// if err := auditLogService.Create(ctx, al); err != nil {
// log.Error(err.Error(), err,
// zap.Int32("user_id", user.ID),
// zap.String("user", user.Email),
// zap.String("ip", al.Ip),
// zap.String("os", al.Os),
// zap.String("method", al.Method),
// zap.String("path", al.Url),
// )
// }
// }()
// }()
//
// next.ServeHTTP(w, r)
// })
// }
//}

View File

@@ -0,0 +1,228 @@
package mid
import (
"context"
"errors"
"net/http"
"sync"
"time"
systemmodel "management/internal/erpserver/model/system"
v1 "management/internal/erpserver/service/v1"
"management/internal/pkg/know"
"management/internal/pkg/session"
"github.com/drhin/logger"
"go.uber.org/zap"
)
// AuditBuffer 审计日志缓冲器
type AuditBuffer struct {
auditLogService v1.AuditLogService
log *logger.Logger
buffer chan *systemmodel.AuditLog
stopCh chan struct{}
wg sync.WaitGroup
batchSize int
flushInterval time.Duration
}
// NewAuditBuffer 创建审计日志缓冲器
func NewAuditBuffer(auditLogService v1.AuditLogService, log *logger.Logger) *AuditBuffer {
return &AuditBuffer{
auditLogService: auditLogService,
log: log,
buffer: make(chan *systemmodel.AuditLog, 10000), // 缓冲区大小
stopCh: make(chan struct{}),
batchSize: 50, // 批量大小
flushInterval: 3 * time.Second, // 刷新间隔
}
}
// Start 启动缓冲器
func (ab *AuditBuffer) Start() {
ab.wg.Add(1)
go ab.processBuffer()
}
// Stop 停止缓冲器
func (ab *AuditBuffer) Stop() {
close(ab.stopCh)
ab.wg.Wait()
close(ab.buffer)
}
// Add 添加审计日志到缓冲区
func (ab *AuditBuffer) Add(auditLog *systemmodel.AuditLog) {
select {
case ab.buffer <- auditLog:
// 成功添加到缓冲区
default:
// 缓冲区满,记录警告但不阻塞
ab.log.Warn("审计日志缓冲区已满,丢弃日志")
}
}
// processBuffer 处理缓冲区中的日志
func (ab *AuditBuffer) processBuffer() {
defer ab.wg.Done()
ticker := time.NewTicker(ab.flushInterval)
defer ticker.Stop()
batch := make([]*systemmodel.AuditLog, 0, ab.batchSize)
flushBatch := func() {
if len(batch) == 0 {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 批量插入
if err := ab.batchInsert(ctx, batch); err != nil {
ab.log.Error("批量插入审计日志失败", err, zap.Int("count", len(batch)))
} else {
ab.log.Debug("批量插入审计日志成功", zap.Int("count", len(batch)))
}
// 清空批次
batch = batch[:0]
}
for {
select {
case <-ab.stopCh:
// 停止信号,处理剩余的日志
for len(ab.buffer) > 0 {
select {
case auditLog := <-ab.buffer:
batch = append(batch, auditLog)
if len(batch) >= ab.batchSize {
flushBatch()
}
default:
break
}
}
flushBatch() // 处理最后一批
return
case <-ticker.C:
// 定时刷新
flushBatch()
case auditLog := <-ab.buffer:
// 收到新的审计日志
batch = append(batch, auditLog)
if len(batch) >= ab.batchSize {
flushBatch()
}
}
}
}
// batchInsert 批量插入数据库
func (ab *AuditBuffer) batchInsert(ctx context.Context, auditLogs []*systemmodel.AuditLog) error {
maxRetries := 3
for i := 0; i < maxRetries; i++ {
// 假设你的服务有批量创建方法,如果没有,需要添加
if err := ab.auditLogService.BatchCreate(ctx, auditLogs); err != nil {
if i == maxRetries-1 {
return err
}
ab.log.Error("批量插入失败,准备重试", err, zap.Int("retry", i+1))
time.Sleep(time.Duration(i+1) * time.Second)
continue
}
return nil
}
return nil
}
// 全局缓冲器实例
var globalAuditBuffer *AuditBuffer
// InitAuditBuffer 初始化全局缓冲器
func InitAuditBuffer(auditLogService v1.AuditLogService, log *logger.Logger) {
globalAuditBuffer = NewAuditBuffer(auditLogService, log)
globalAuditBuffer.Start()
}
// StopAuditBuffer 停止全局缓冲器
func StopAuditBuffer() {
if globalAuditBuffer != nil {
globalAuditBuffer.Stop()
}
}
// Audit 优化后的中间件
func Audit(sess session.Manager, log *logger.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// 提前获取用户信息
user, err := sess.GetUser(r.Context(), know.StoreName)
if err != nil {
log.Error("获取用户会话失败", err)
next.ServeHTTP(w, r)
return
}
// 处理请求
next.ServeHTTP(w, r)
// 异步添加到缓冲区
go func() {
if user.ID == 0 {
log.Error("用户信息为空", errors.New("user is empty"))
return
}
auditLog := systemmodel.NewAuditLog(r, user.Email, user.OS, user.Browser, start, time.Now())
// 添加到缓冲区,不会阻塞
if globalAuditBuffer != nil {
globalAuditBuffer.Add(auditLog)
}
}()
})
}
}
// 如果你的AuditLogService没有BatchCreate方法需要添加这个接口
// 在你的service接口中添加
/*
type AuditLogService interface {
Create(ctx context.Context, auditLog *systemmodel.AuditLog) error
BatchCreate(ctx context.Context, auditLogs []*systemmodel.AuditLog) error
// ... 其他方法
}
*/
// 以及对应的实现PostgreSQL批量插入示例
/*
func (s *auditLogService) BatchCreate(ctx context.Context, auditLogs []*systemmodel.AuditLog) error {
if len(auditLogs) == 0 {
return nil
}
// 构建批量插入SQL
query := `INSERT INTO audit_logs (user_id, email, ip, os, browser, method, url, start_time, end_time, duration) VALUES `
values := make([]interface{}, 0, len(auditLogs)*10)
for i, log := range auditLogs {
if i > 0 {
query += ", "
}
query += "($" + strconv.Itoa(i*10+1) + ", $" + strconv.Itoa(i*10+2) + ", $" + strconv.Itoa(i*10+3) + ", $" + strconv.Itoa(i*10+4) + ", $" + strconv.Itoa(i*10+5) + ", $" + strconv.Itoa(i*10+6) + ", $" + strconv.Itoa(i*10+7) + ", $" + strconv.Itoa(i*10+8) + ", $" + strconv.Itoa(i*10+9) + ", $" + strconv.Itoa(i*10+10) + ")"
values = append(values, log.UserID, log.Email, log.Ip, log.Os, log.Browser, log.Method, log.Url, log.StartTime, log.EndTime, log.Duration)
}
_, err := s.db.ExecContext(ctx, query, values...)
return err
}
*/

View File

@@ -0,0 +1,99 @@
package mid
//import (
// "log"
// "net/http"
// "time"
//
// "management/internal/erpserver/model/dto"
// v1 "management/internal/erpserver/service/v1"
// "management/internal/pkg/know"
// "management/internal/pkg/session"
//)
//
//var publicRoutes = map[string]bool{
// "/home.html": true,
// "/dashboard": true,
// "/system/menus": true,
// "/upload/img": true,
// "/upload/file": true,
// "/upload/multi_files": true,
// "/pear.json": true,
// "/logout": true,
//}
//
//func Authorize(
// sess session.Manager,
// menuService v1.MenuService,
//) func(http.Handler) http.Handler {
// return func(next http.Handler) http.Handler {
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// ctx := r.Context()
// path := r.URL.Path
//
// // 登陆检查
// n := time.Now()
// user, err := sess.GetUser(ctx, know.StoreName)
// if err != nil || user.ID == 0 {
// http.Redirect(w, r, "/", http.StatusFound)
// return
// }
//
// log.Printf("scs get user: %s", time.Since(n).String())
//
// // 公共路由放行
// if publicRoutes[path] {
// ctx = setUser(ctx, user)
// next.ServeHTTP(w, r.WithContext(ctx))
// return
// }
//
// n1 := time.Now()
// // 权限检查
// menus, err := menuService.ListByRoleIDToMap(ctx, user.RoleID)
// if err != nil || !hasPermission(menus, path) {
// http.Error(w, "Forbidden", http.StatusForbidden)
// return
// }
//
// log.Printf("listByRoleIDToMap: %s", time.Since(n1).String())
//
// n2 := time.Now()
// cur := getCurrentMenus(menus, path)
// log.Printf("getCurrentMenus: %s", time.Since(n2).String())
//
// ctx = setUser(ctx, user)
// ctx = setCurMenus(ctx, cur)
//
// next.ServeHTTP(w, r.WithContext(ctx))
// })
// }
//}
//
//func hasPermission(menus map[string]*dto.OwnerMenuDto, path string) bool {
// _, ok := menus[path]
// return ok
//}
//
//func getCurrentMenus(data map[string]*dto.OwnerMenuDto, path string) []dto.OwnerMenuDto {
// var res []dto.OwnerMenuDto
//
// menu, ok := data[path]
// if !ok {
// return res
// }
//
// for _, item := range data {
// if menu.IsList {
// if item.ParentID == menu.ID || item.ID == menu.ID {
// res = append(res, *item)
// }
// } else {
// if item.ParentID == menu.ParentID {
// res = append(res, *item)
// }
// }
// }
//
// return res
//}

View File

@@ -0,0 +1,402 @@
package mid
//
//import (
// "context"
// "log"
// "net/http"
// "sync"
// "time"
//
// "management/internal/erpserver/model/dto"
// v1 "management/internal/erpserver/service/v1"
// "management/internal/pkg/know"
// "management/internal/pkg/session"
//)
//
//var publicRoutes = map[string]bool{
// "/home.html": true,
// "/dashboard": true,
// "/system/menus": true,
// "/upload/img": true,
// "/upload/file": true,
// "/upload/multi_files": true,
// "/pear.json": true,
// "/logout": true,
//}
//
//// MenuCacheItem 菜单缓存项
//type MenuCacheItem struct {
// Data map[string]*dto.OwnerMenuDto
// ExpireAt time.Time
// LoadTime time.Time
// Version int64
//}
//
//// IsExpired 检查是否过期
//func (item *MenuCacheItem) IsExpired() bool {
// return time.Now().After(item.ExpireAt)
//}
//
//// MenuCache 内存缓存管理器
//type MenuCache struct {
// mu sync.RWMutex
// cache map[int32]*MenuCacheItem // roleID -> MenuCacheItem
// maxSize int // 最大缓存条目数
// ttl time.Duration // 缓存TTL
// refreshTTL time.Duration // 刷新TTL (提前刷新时间)
// stats CacheStats // 缓存统计
// cleanupTick *time.Ticker // 清理定时器
// stopCh chan struct{} // 停止信号
//}
//
//// CacheStats 缓存统计信息
//type CacheStats struct {
// mu sync.RWMutex
// Hits int64
// Misses int64
// Evictions int64
// RefreshCount int64
//}
//
//// NewMenuCache 创建新的菜单缓存
//func NewMenuCache(maxSize int, ttl time.Duration) *MenuCache {
// cache := &MenuCache{
// cache: make(map[int32]*MenuCacheItem),
// maxSize: maxSize,
// ttl: ttl,
// refreshTTL: ttl - time.Duration(float64(ttl)*0.1), // 提前10%刷新
// stopCh: make(chan struct{}),
// cleanupTick: time.NewTicker(time.Minute * 5), // 每5分钟清理一次
// }
//
// // 启动后台清理协程
// go cache.cleanup()
//
// return cache
//}
//
//// Get 获取缓存数据
//func (mc *MenuCache) Get(roleID int32) (map[string]*dto.OwnerMenuDto, bool) {
// mc.mu.RLock()
// item, exists := mc.cache[roleID]
// mc.mu.RUnlock()
//
// if !exists {
// mc.recordMiss()
// return nil, false
// }
//
// // 检查是否过期
// if item.IsExpired() {
// mc.recordMiss()
// // 异步删除过期项
// go func() {
// mc.mu.Lock()
// delete(mc.cache, roleID)
// mc.mu.Unlock()
// }()
// return nil, false
// }
//
// mc.recordHit()
// return item.Data, true
//}
//
//// Set 设置缓存数据
//func (mc *MenuCache) Set(roleID int32, data map[string]*dto.OwnerMenuDto, version int64) {
// mc.mu.Lock()
// defer mc.mu.Unlock()
//
// // 如果缓存已满执行LRU淘汰
// if len(mc.cache) >= mc.maxSize {
// mc.evictLRU()
// }
//
// item := &MenuCacheItem{
// Data: data,
// ExpireAt: time.Now().Add(mc.ttl),
// LoadTime: time.Now(),
// Version: version,
// }
//
// mc.cache[roleID] = item
//}
//
//// NeedRefresh 检查是否需要提前刷新
//func (mc *MenuCache) NeedRefresh(roleID int32) bool {
// mc.mu.RLock()
// item, exists := mc.cache[roleID]
// mc.mu.RUnlock()
//
// if !exists {
// return true
// }
//
// // 提前刷新策略在过期前10%时间开始刷新
// refreshTime := item.LoadTime.Add(mc.refreshTTL)
// return time.Now().After(refreshTime)
//}
//
//// evictLRU 淘汰最久未使用的缓存项
//func (mc *MenuCache) evictLRU() {
// var oldestRoleID int32
// var oldestTime time.Time = time.Now()
//
// for roleID, item := range mc.cache {
// if item.LoadTime.Before(oldestTime) {
// oldestTime = item.LoadTime
// oldestRoleID = roleID
// }
// }
//
// if oldestRoleID != 0 {
// delete(mc.cache, oldestRoleID)
// mc.stats.mu.Lock()
// mc.stats.Evictions++
// mc.stats.mu.Unlock()
// }
//}
//
//// cleanup 后台清理过期缓存
//func (mc *MenuCache) cleanup() {
// for {
// select {
// case <-mc.cleanupTick.C:
// mc.cleanupExpired()
// case <-mc.stopCh:
// mc.cleanupTick.Stop()
// return
// }
// }
//}
//
//// cleanupExpired 清理过期缓存
//func (mc *MenuCache) cleanupExpired() {
// mc.mu.Lock()
// defer mc.mu.Unlock()
//
// now := time.Now()
// for roleID, item := range mc.cache {
// if now.After(item.ExpireAt) {
// delete(mc.cache, roleID)
// }
// }
//}
//
//// GetStats 获取缓存统计信息
//func (mc *MenuCache) GetStats() CacheStats {
// mc.stats.mu.RLock()
// defer mc.stats.mu.RUnlock()
// return mc.stats
//}
//
//// recordHit 记录缓存命中
//func (mc *MenuCache) recordHit() {
// mc.stats.mu.Lock()
// mc.stats.Hits++
// mc.stats.mu.Unlock()
//}
//
//// recordMiss 记录缓存未命中
//func (mc *MenuCache) recordMiss() {
// mc.stats.mu.Lock()
// mc.stats.Misses++
// mc.stats.mu.Unlock()
//}
//
//// Close 关闭缓存
//func (mc *MenuCache) Close() {
// close(mc.stopCh)
//}
//
//// CachedMenuService 带缓存的菜单服务包装器
//type CachedMenuService struct {
// menuService v1.MenuService
// cache *MenuCache
// mu sync.RWMutex
// refreshing map[int32]bool // 正在刷新的roleID
//}
//
//// NewCachedMenuService 创建带缓存的菜单服务
//func NewCachedMenuService(menuService v1.MenuService, cache *MenuCache) *CachedMenuService {
// return &CachedMenuService{
// menuService: menuService,
// cache: cache,
// refreshing: make(map[int32]bool),
// }
//}
//
//// ListByRoleIDToMap 获取菜单数据(带缓存)
//func (cms *CachedMenuService) ListByRoleIDToMap(ctx context.Context, roleID int32) (map[string]*dto.OwnerMenuDto, error) {
// // 先尝试从缓存获取
// if data, hit := cms.cache.Get(roleID); hit {
// // 检查是否需要异步刷新
// if cms.cache.NeedRefresh(roleID) {
// go cms.asyncRefresh(ctx, roleID)
// }
// return data, nil
// }
//
// // 缓存未命中,同步获取数据
// return cms.loadAndCache(ctx, roleID)
//}
//
//// loadAndCache 加载数据并缓存
//func (cms *CachedMenuService) loadAndCache(ctx context.Context, roleID int32) (map[string]*dto.OwnerMenuDto, error) {
// // 防止并发重复加载
// cms.mu.Lock()
// if cms.refreshing[roleID] {
// cms.mu.Unlock()
// // 如果正在加载,等待一小段时间后重试缓存
// time.Sleep(time.Millisecond * 10)
// if data, hit := cms.cache.Get(roleID); hit {
// return data, nil
// }
// // 重试失败,继续执行加载逻辑
// cms.mu.Lock()
// }
// cms.refreshing[roleID] = true
// cms.mu.Unlock()
//
// defer func() {
// cms.mu.Lock()
// delete(cms.refreshing, roleID)
// cms.mu.Unlock()
// }()
//
// // 从原始服务获取数据
// data, err := cms.menuService.ListByRoleIDToMap(ctx, roleID)
// if err != nil {
// return nil, err
// }
//
// // 缓存数据
// version := time.Now().UnixNano()
// cms.cache.Set(roleID, data, version)
//
// return data, nil
//}
//
//// asyncRefresh 异步刷新缓存
//func (cms *CachedMenuService) asyncRefresh(ctx context.Context, roleID int32) {
// // 使用背景上下文,避免原请求取消影响刷新
// bgCtx := context.Background()
//
// cms.mu.RLock()
// if cms.refreshing[roleID] {
// cms.mu.RUnlock()
// return
// }
// cms.mu.RUnlock()
//
// _, err := cms.loadAndCache(bgCtx, roleID)
// if err != nil {
// log.Printf("async refresh menu cache failed for roleID %d: %v", roleID, err)
// }
//
// cms.cache.stats.mu.Lock()
// cms.cache.stats.RefreshCount++
// cms.cache.stats.mu.Unlock()
//}
//
//// 全局缓存实例
//var (
// menuCache *MenuCache
// cachedMenuSvc *CachedMenuService
// cacheInitOnce sync.Once
//)
//
//// InitMenuCache 初始化菜单缓存(在应用启动时调用)
//func InitMenuCache(menuService v1.MenuService) {
// cacheInitOnce.Do(func() {
// // 配置参数最大1000个角色的缓存TTL 10分钟
// menuCache = NewMenuCache(1000, time.Minute*10)
// cachedMenuSvc = NewCachedMenuService(menuService, menuCache)
// })
//}
//
//// GetCachedMenuService 获取缓存菜单服务实例
//func GetCachedMenuService() *CachedMenuService {
// return cachedMenuSvc
//}
//
//// Authorize 修改后的授权中间件
//func Authorize(
// sess session.Manager,
// menuService v1.MenuService,
//) func(http.Handler) http.Handler {
// // 初始化缓存
// InitMenuCache(menuService)
//
// return func(next http.Handler) http.Handler {
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// ctx := r.Context()
// path := r.URL.Path
//
// // 登陆检查
// n := time.Now()
// user, err := sess.GetUser(ctx, know.StoreName)
// if err != nil || user.ID == 0 {
// http.Redirect(w, r, "/", http.StatusFound)
// return
// }
//
// log.Printf("scs get user: %s", time.Since(n).String())
//
// // 公共路由放行
// if publicRoutes[path] {
// ctx = setUser(ctx, user)
// next.ServeHTTP(w, r.WithContext(ctx))
// return
// }
//
// n1 := time.Now()
// // 权限检查 - 使用缓存服务
// menus, err := GetCachedMenuService().ListByRoleIDToMap(ctx, user.RoleID)
// if err != nil || !hasPermission(menus, path) {
// http.Error(w, "Forbidden", http.StatusForbidden)
// return
// }
//
// log.Printf("listByRoleIDToMap (cached): %s", time.Since(n1).String())
//
// n2 := time.Now()
// cur := getCurrentMenus(menus, path)
// log.Printf("getCurrentMenus: %s", time.Since(n2).String())
//
// ctx = setUser(ctx, user)
// ctx = setCurMenus(ctx, cur)
//
// next.ServeHTTP(w, r.WithContext(ctx))
// })
// }
//}
//
//func hasPermission(menus map[string]*dto.OwnerMenuDto, path string) bool {
// _, ok := menus[path]
// return ok
//}
//
//func getCurrentMenus(data map[string]*dto.OwnerMenuDto, path string) []dto.OwnerMenuDto {
// var res []dto.OwnerMenuDto
//
// menu, ok := data[path]
// if !ok {
// return res
// }
//
// for _, item := range data {
// if menu.IsList {
// if item.ParentID == menu.ID || item.ID == menu.ID {
// res = append(res, *item)
// }
// } else {
// if item.ParentID == menu.ParentID {
// res = append(res, *item)
// }
// }
// }
//
// return res
//}

View File

@@ -0,0 +1,537 @@
package mid
//import (
// "context"
// "encoding/json"
// "errors"
// "fmt"
// "log"
// "net/http"
// "strconv"
// "sync"
// "time"
//
// "management/internal/erpserver/model/dto"
// v1 "management/internal/erpserver/service/v1"
// "management/internal/pkg/know"
// "management/internal/pkg/session"
//
// "github.com/allegro/bigcache/v3"
//)
//
//var publicRoutes = map[string]bool{
// "/home.html": true,
// "/dashboard": true,
// "/system/menus": true,
// "/upload/img": true,
// "/upload/file": true,
// "/upload/multi_files": true,
// "/pear.json": true,
// "/logout": true,
//}
//
//// MenuCacheEntry 菜单缓存条目
//type MenuCacheEntry struct {
// Data map[string]*dto.OwnerMenuDto `json:"data"`
// Timestamp int64 `json:"timestamp"`
// Version int64 `json:"version"`
//}
//
//// CacheStats 缓存统计信息
//type CacheStats struct {
// mu sync.RWMutex
// Hits int64
// Misses int64
// Errors int64
// RefreshCount int64
// AsyncRefreshHits int64
// LastRefreshTime time.Time
//}
//
//// GetHitRate 获取命中率
//func (cs *CacheStats) GetHitRate() float64 {
// cs.mu.RLock()
// defer cs.mu.RUnlock()
//
// total := cs.Hits + cs.Misses
// if total == 0 {
// return 0
// }
// return float64(cs.Hits) / float64(total) * 100
//}
//
//// MenuCacheManager BigCache菜单缓存管理器
//type MenuCacheManager struct {
// cache *bigcache.BigCache
// refreshTTL time.Duration
// stats *CacheStats
// mu sync.RWMutex
// refreshing map[int64]bool
// stopCh chan struct{}
// monitorTicker *time.Ticker
//}
//
//// NewMenuCacheManager 创建菜单缓存管理器
//func NewMenuCacheManager(ttl time.Duration) (*MenuCacheManager, error) {
// config := bigcache.DefaultConfig(ttl)
//
// // 生产环境优化配置
// config.Shards = 256 // 分片数,减少锁竞争
// config.MaxEntriesInWindow = 10000 // 窗口内最大条目数
// config.MaxEntrySize = 1024 * 50 // 最大条目50KB
// config.HardMaxCacheSize = 512 // 最大缓存512MB
// config.StatsEnabled = true // 启用统计
// config.Verbose = false // 关闭详细日志
// config.CleanWindow = 5 * time.Minute // 清理窗口
//
// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
// defer cancel()
//
// cache, err := bigcache.New(ctx, config)
// if err != nil {
// return nil, fmt.Errorf("failed to create BigCache: %w", err)
// }
//
// manager := &MenuCacheManager{
// cache: cache,
// refreshTTL: time.Duration(float64(ttl) * 0.8), // 80%时间后开始刷新
// stats: &CacheStats{},
// refreshing: make(map[int64]bool),
// stopCh: make(chan struct{}),
// monitorTicker: time.NewTicker(time.Minute), // 每分钟监控一次
// }
//
// // 启动监控协程
// go manager.monitor()
//
// return manager, nil
//}
//
//// Get 获取缓存数据
//func (mcm *MenuCacheManager) Get(roleID int32) (map[string]*dto.OwnerMenuDto, bool, bool) {
// key := mcm.makeKey(roleID)
//
// data, err := mcm.cache.Get(key)
// if err != nil {
// if !errors.Is(err, bigcache.ErrEntryNotFound) {
// mcm.recordError()
// log.Printf("BigCache get error for roleID %d: %v", roleID, err)
// }
// mcm.recordMiss()
// return nil, false, false
// }
//
// // 反序列化
// var entry MenuCacheEntry
// if err := json.Unmarshal(data, &entry); err != nil {
// mcm.recordError()
// log.Printf("Failed to unmarshal cache entry for roleID %d: %v", roleID, err)
// mcm.recordMiss()
// return nil, false, false
// }
//
// mcm.recordHit()
//
// // 检查是否需要刷新
// needRefresh := time.Since(time.Unix(entry.Timestamp, 0)) > mcm.refreshTTL
//
// return entry.Data, true, needRefresh
//}
//
//// Set 设置缓存数据
//func (mcm *MenuCacheManager) Set(roleID int32, data map[string]*dto.OwnerMenuDto) error {
// key := mcm.makeKey(roleID)
//
// entry := MenuCacheEntry{
// Data: data,
// Timestamp: time.Now().Unix(),
// Version: time.Now().UnixNano(),
// }
//
// // 序列化
// entryData, err := json.Marshal(entry)
// if err != nil {
// mcm.recordError()
// return fmt.Errorf("failed to marshal cache entry: %w", err)
// }
//
// // 存储到BigCache
// err = mcm.cache.Set(key, entryData)
// if err != nil {
// mcm.recordError()
// return fmt.Errorf("failed to set cache: %w", err)
// }
//
// return nil
//}
//
//// Delete 删除缓存数据
//func (mcm *MenuCacheManager) Delete(roleID int32) error {
// key := mcm.makeKey(roleID)
// err := mcm.cache.Delete(key)
// if err != nil && !errors.Is(err, bigcache.ErrEntryNotFound) {
// mcm.recordError()
// return fmt.Errorf("failed to delete cache: %w", err)
// }
// return nil
//}
//
//// makeKey 生成缓存key
//func (mcm *MenuCacheManager) makeKey(roleID int32) string {
// return "menu:role:" + strconv.Itoa(int(roleID))
//}
//
//// GetStats 获取统计信息
//func (mcm *MenuCacheManager) GetStats() *CacheStats {
// mcm.stats.mu.RLock()
// defer mcm.stats.mu.RUnlock()
//
// // 复制统计数据
// stats := *mcm.stats
// return &stats
//}
//
//// GetBigCacheStats 获取BigCache原生统计
//func (mcm *MenuCacheManager) GetBigCacheStats() bigcache.Stats {
// return mcm.cache.Stats()
//}
//
//// recordHit 记录命中
//func (mcm *MenuCacheManager) recordHit() {
// mcm.stats.mu.Lock()
// mcm.stats.Hits++
// mcm.stats.mu.Unlock()
//}
//
//// recordMiss 记录未命中
//func (mcm *MenuCacheManager) recordMiss() {
// mcm.stats.mu.Lock()
// mcm.stats.Misses++
// mcm.stats.mu.Unlock()
//}
//
//// recordError 记录错误
//func (mcm *MenuCacheManager) recordError() {
// mcm.stats.mu.Lock()
// mcm.stats.Errors++
// mcm.stats.mu.Unlock()
//}
//
//// recordRefresh 记录刷新
//func (mcm *MenuCacheManager) recordRefresh() {
// mcm.stats.mu.Lock()
// mcm.stats.RefreshCount++
// mcm.stats.LastRefreshTime = time.Now()
// mcm.stats.mu.Unlock()
//}
//
//// recordAsyncRefreshHit 记录异步刷新命中
//func (mcm *MenuCacheManager) recordAsyncRefreshHit() {
// mcm.stats.mu.Lock()
// mcm.stats.AsyncRefreshHits++
// mcm.stats.mu.Unlock()
//}
//
//// monitor 监控协程
//func (mcm *MenuCacheManager) monitor() {
// for {
// select {
// case <-mcm.monitorTicker.C:
// mcm.logStats()
// case <-mcm.stopCh:
// mcm.monitorTicker.Stop()
// return
// }
// }
//}
//
//// logStats 记录统计信息
//func (mcm *MenuCacheManager) logStats() {
// stats := mcm.GetStats()
// bigCacheStats := mcm.GetBigCacheStats()
//
// log.Printf("MenuCache Stats - Hits: %d, Misses: %d, Errors: %d, HitRate: %.2f%%, "+
// "RefreshCount: %d, AsyncRefreshHits: %d, BigCache Hits: %d, BigCache Misses: %d",
// stats.Hits, stats.Misses, stats.Errors, stats.GetHitRate(),
// stats.RefreshCount, stats.AsyncRefreshHits,
// bigCacheStats.Hits, bigCacheStats.Misses)
//}
//
//// Close 关闭缓存管理器
//func (mcm *MenuCacheManager) Close() error {
// close(mcm.stopCh)
// return mcm.cache.Close()
//}
//
//// CachedMenuService 带BigCache的菜单服务
//type CachedMenuService struct {
// menuService v1.MenuService
// cache *MenuCacheManager
// mu sync.RWMutex
// refreshing map[int32]bool
//}
//
//// NewCachedMenuService 创建带缓存的菜单服务
//func NewCachedMenuService(menuService v1.MenuService, cache *MenuCacheManager) *CachedMenuService {
// return &CachedMenuService{
// menuService: menuService,
// cache: cache,
// refreshing: make(map[int32]bool),
// }
//}
//
//// ListByRoleIDToMap 获取菜单数据带BigCache缓存
//func (cms *CachedMenuService) ListByRoleIDToMap(ctx context.Context, roleID int32) (map[string]*dto.OwnerMenuDto, error) {
// // 尝试从缓存获取
// data, hit, needRefresh := cms.cache.Get(roleID)
// if hit {
// // 如果需要刷新且当前没有在刷新中,启动异步刷新
// if needRefresh && !cms.isRefreshing(roleID) {
// go cms.asyncRefresh(roleID)
// }
// return data, nil
// }
//
// // 缓存未命中,同步获取
// return cms.loadAndCache(ctx, roleID)
//}
//
//// loadAndCache 加载数据并缓存
//func (cms *CachedMenuService) loadAndCache(ctx context.Context, roleID int32) (map[string]*dto.OwnerMenuDto, error) {
// // 防止并发重复加载
// cms.mu.Lock()
// if cms.refreshing[roleID] {
// cms.mu.Unlock()
// // 等待一小段时间后重试缓存
// time.Sleep(time.Millisecond * 5)
// if data, hit, _ := cms.cache.Get(roleID); hit {
// return data, nil
// }
// // 重试失败,继续加载
// cms.mu.Lock()
// }
// cms.refreshing[roleID] = true
// cms.mu.Unlock()
//
// defer func() {
// cms.mu.Lock()
// delete(cms.refreshing, roleID)
// cms.mu.Unlock()
// }()
//
// // 从原始服务获取数据
// data, err := cms.menuService.ListByRoleIDToMap(ctx, roleID)
// if err != nil {
// return nil, err
// }
//
// // 缓存数据
// if cacheErr := cms.cache.Set(roleID, data); cacheErr != nil {
// log.Printf("Failed to cache menu data for roleID %d: %v", roleID, cacheErr)
// // 缓存失败不影响业务逻辑,继续返回数据
// }
//
// return data, nil
//}
//
//// asyncRefresh 异步刷新缓存
//func (cms *CachedMenuService) asyncRefresh(roleID int32) {
// // 检查是否已在刷新中
// if cms.isRefreshing(roleID) {
// return
// }
//
// // 使用背景上下文
// ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
// defer cancel()
//
// _, err := cms.loadAndCache(ctx, roleID)
// if err != nil {
// log.Printf("Async refresh menu cache failed for roleID %d: %v", roleID, err)
// } else {
// cms.cache.recordRefresh()
// cms.cache.recordAsyncRefreshHit()
// }
//}
//
//// isRefreshing 检查是否正在刷新
//func (cms *CachedMenuService) isRefreshing(roleID int32) bool {
// cms.mu.RLock()
// refreshing := cms.refreshing[roleID]
// cms.mu.RUnlock()
// return refreshing
//}
//
//// InvalidateRole 使指定角色缓存失效
//func (cms *CachedMenuService) InvalidateRole(roleID int32) error {
// return cms.cache.Delete(roleID)
//}
//
//// GetCacheStats 获取缓存统计
//func (cms *CachedMenuService) GetCacheStats() *CacheStats {
// return cms.cache.GetStats()
//}
//
//// 全局实例
//var (
// menuCacheManager *MenuCacheManager
// cachedMenuSvc *CachedMenuService
// cacheInitOnce sync.Once
// cacheInitErr error
//)
//
//// InitMenuCache 初始化菜单缓存
//func InitMenuCache(menuService v1.MenuService) error {
// cacheInitOnce.Do(func() {
// // 缓存TTL设置为15分钟
// manager, err := NewMenuCacheManager(15 * time.Minute)
// if err != nil {
// cacheInitErr = fmt.Errorf("failed to initialize menu cache: %w", err)
// return
// }
//
// menuCacheManager = manager
// cachedMenuSvc = NewCachedMenuService(menuService, manager)
//
// log.Println("MenuCache initialized successfully with BigCache")
// })
//
// return cacheInitErr
//}
//
//// GetCachedMenuService 获取缓存菜单服务
//func GetCachedMenuService() *CachedMenuService {
// return cachedMenuSvc
//}
//
//// GetMenuCacheManager 获取缓存管理器
//func GetMenuCacheManager() *MenuCacheManager {
// return menuCacheManager
//}
//
//// CloseMenuCache 关闭菜单缓存
//func CloseMenuCache() error {
// if menuCacheManager != nil {
// return menuCacheManager.Close()
// }
// return nil
//}
//
//// Authorize 修改后的授权中间件
//func Authorize(
// sess session.Manager,
// menuService v1.MenuService,
//) func(http.Handler) http.Handler {
// // 初始化缓存
// if err := InitMenuCache(menuService); err != nil {
// log.Printf("Failed to initialize menu cache: %v", err)
// // 缓存初始化失败,降级到原始服务
// return func(next http.Handler) http.Handler {
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// ctx := r.Context()
// path := r.URL.Path
//
// // 登陆检查
// n := time.Now()
// user, err := sess.GetUser(ctx, know.StoreName)
// if err != nil || user.ID == 0 {
// http.Redirect(w, r, "/", http.StatusFound)
// return
// }
// log.Printf("scs get user: %s", time.Since(n).String())
//
// // 公共路由放行
// if publicRoutes[path] {
// ctx = setUser(ctx, user)
// next.ServeHTTP(w, r.WithContext(ctx))
// return
// }
//
// n1 := time.Now()
// // 权限检查 - 使用原始服务
// menus, err := menuService.ListByRoleIDToMap(ctx, user.RoleID)
// if err != nil || !hasPermission(menus, path) {
// http.Error(w, "Forbidden", http.StatusForbidden)
// return
// }
// log.Printf("listByRoleIDToMap (fallback): %s", time.Since(n1).String())
//
// n2 := time.Now()
// cur := getCurrentMenus(menus, path)
// log.Printf("getCurrentMenus: %s", time.Since(n2).String())
//
// ctx = setUser(ctx, user)
// ctx = setCurMenus(ctx, cur)
// next.ServeHTTP(w, r.WithContext(ctx))
// })
// }
// }
//
// return func(next http.Handler) http.Handler {
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// ctx := r.Context()
// path := r.URL.Path
//
// // 登陆检查
// n := time.Now()
// user, err := sess.GetUser(ctx, know.StoreName)
// if err != nil || user.ID == 0 {
// http.Redirect(w, r, "/", http.StatusFound)
// return
// }
// log.Printf("scs get user: %s", time.Since(n).String())
//
// // 公共路由放行
// if publicRoutes[path] {
// ctx = setUser(ctx, user)
// next.ServeHTTP(w, r.WithContext(ctx))
// return
// }
//
// n1 := time.Now()
// // 权限检查 - 使用BigCache缓存服务
// menus, err := GetCachedMenuService().ListByRoleIDToMap(ctx, user.RoleID)
// if err != nil || !hasPermission(menus, path) {
// http.Error(w, "Forbidden", http.StatusForbidden)
// return
// }
// log.Printf("listByRoleIDToMap (BigCache): %s", time.Since(n1).String())
//
// n2 := time.Now()
// cur := getCurrentMenus(menus, path)
// log.Printf("getCurrentMenus: %s", time.Since(n2).String())
//
// ctx = setUser(ctx, user)
// ctx = setCurMenus(ctx, cur)
// next.ServeHTTP(w, r.WithContext(ctx))
// })
// }
//}
//
//func hasPermission(menus map[string]*dto.OwnerMenuDto, path string) bool {
// _, ok := menus[path]
// return ok
//}
//
//func getCurrentMenus(data map[string]*dto.OwnerMenuDto, path string) []dto.OwnerMenuDto {
// var res []dto.OwnerMenuDto
//
// menu, ok := data[path]
// if !ok {
// return res
// }
//
// for _, item := range data {
// if menu.IsList {
// if item.ParentID == menu.ID || item.ID == menu.ID {
// res = append(res, *item)
// }
// } else {
// if item.ParentID == menu.ParentID {
// res = append(res, *item)
// }
// }
// }
//
// return res
//}

View File

@@ -0,0 +1,121 @@
package mid
import (
"fmt"
"log"
"net/http"
"time"
"management/internal/erpserver/model/dto"
v1 "management/internal/erpserver/service/v1"
"management/internal/pkg/know"
"management/internal/pkg/session"
"github.com/patrickmn/go-cache"
)
var publicRoutes = map[string]bool{
"/home.html": true,
"/dashboard": true,
"/system/menus": true,
"/upload/img": true,
"/upload/file": true,
"/upload/multi_files": true,
"/pear.json": true,
"/logout": true,
}
// 定义一个全局的go-cache实例
var menuCache *cache.Cache
func init() {
// 初始化go-cache设置默认过期时间为5分钟每10分钟清理一次过期项
menuCache = cache.New(5*time.Minute, 10*time.Minute)
}
func Authorize(
sess session.Manager,
menuService v1.MenuService,
) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
path := r.URL.Path
// 登陆检查
user, err := sess.GetUser(ctx, know.StoreName)
if err != nil || user.ID == 0 {
http.Redirect(w, r, "/", http.StatusFound)
return
}
// 公共路由放行
if publicRoutes[path] {
ctx = setUser(ctx, user)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
n1 := time.Now()
// 权限检查
var menus map[string]*dto.OwnerMenuDto
cacheKey := fmt.Sprintf("user_menus:%d", user.RoleID) // 使用用户RoleID作为缓存key
// 尝试从内存缓存中获取菜单数据
if cachedMenus, found := menuCache.Get(cacheKey); found {
menus = cachedMenus.(map[string]*dto.OwnerMenuDto)
log.Printf("listByRoleIDToMap (from cache): %s", time.Since(n1).String())
} else {
// 内存缓存未命中从menuService获取并存入内存缓存
menus, err = menuService.ListByRoleIDToMap(ctx, user.RoleID)
if err != nil {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
menuCache.Set(cacheKey, menus, cache.DefaultExpiration) // 使用默认过期时间
log.Printf("listByRoleIDToMap (from service, then cached): %s", time.Since(n1).String())
}
if !hasPermission(menus, path) {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
cur := getCurrentMenus(menus, path)
ctx = setUser(ctx, user)
ctx = setCurMenus(ctx, cur)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
func hasPermission(menus map[string]*dto.OwnerMenuDto, path string) bool {
_, ok := menus[path]
return ok
}
func getCurrentMenus(data map[string]*dto.OwnerMenuDto, path string) []dto.OwnerMenuDto {
var res []dto.OwnerMenuDto
menu, ok := data[path]
if !ok {
return res
}
for _, item := range data {
if menu.IsList {
if item.ParentID == menu.ID || item.ID == menu.ID {
res = append(res, *item)
}
} else {
if item.ParentID == menu.ParentID {
res = append(res, *item)
}
}
}
return res
}

22
internal/pkg/mid/csrf.go Normal file
View File

@@ -0,0 +1,22 @@
package mid
import (
"fmt"
"net/http"
"github.com/justinas/nosurf"
)
func NoSurf(next http.Handler) http.Handler {
return nosurf.New(next)
}
func NoSurfContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := nosurf.Token(r)
ctx := setCsrfToken(r.Context(), token)
ctx = setHtmlCsrfToken(ctx, fmt.Sprintf(`<input type="hidden" name="csrf_token" value="%s" />`, token))
next.ServeHTTP(w, r.WithContext(ctx))
})
}

75
internal/pkg/mid/mid.go Normal file
View File

@@ -0,0 +1,75 @@
package mid
import (
"context"
"management/internal/erpserver/model/dto"
"github.com/a-h/templ"
)
type userKey struct{}
func setUser(ctx context.Context, usr dto.AuthorizeUser) context.Context {
return context.WithValue(ctx, userKey{}, usr)
}
// GetUser returns the user from the context.
func GetUser(ctx context.Context) dto.AuthorizeUser {
v, ok := ctx.Value(userKey{}).(dto.AuthorizeUser)
if !ok {
return dto.AuthorizeUser{}
}
return v
}
type menuKey struct{}
func setCurMenus(ctx context.Context, ms []dto.OwnerMenuDto) context.Context {
return context.WithValue(ctx, menuKey{}, ms)
}
func GetCurMenus(ctx context.Context) []dto.OwnerMenuDto {
v, ok := ctx.Value(menuKey{}).([]dto.OwnerMenuDto)
if !ok {
return []dto.OwnerMenuDto{}
}
return v
}
type NoSurfToken struct {
Token string
HtmlToken string
}
type csrfKey struct{}
func setCsrfToken(ctx context.Context, token string) context.Context {
return context.WithValue(ctx, csrfKey{}, token)
}
func GetCsrfToken(ctx context.Context) string {
v, ok := ctx.Value(csrfKey{}).(string)
if !ok {
return ""
}
return v
}
type htmlCsrfKey struct{}
func setHtmlCsrfToken(ctx context.Context, token string) context.Context {
return context.WithValue(ctx, htmlCsrfKey{}, templ.Raw(token))
}
func GetHtmlCsrfToken(ctx context.Context) templ.Component {
v, ok := ctx.Value(htmlCsrfKey{}).(templ.Component)
if !ok {
return templ.Raw("")
}
return v
}

View File

@@ -1,4 +1,4 @@
package middleware
package mid
import (
"net/http"

View File

@@ -1,58 +0,0 @@
package middleware
import (
"context"
"errors"
"net/http"
"time"
systemmodel "management/internal/erpserver/model/system"
v1 "management/internal/erpserver/service/v1"
"management/internal/pkg/know"
"management/internal/pkg/session"
"github.com/drhin/logger"
"go.uber.org/zap"
)
func Audit(sess session.Manager, auditLogService v1.AuditLogService, log *logger.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// 提前获取用户信息(同步操作)
user, err := sess.GetUser(r.Context(), know.StoreName)
if err != nil {
log.Error("获取用户会话失败", err)
next.ServeHTTP(w, r) // 继续处理请求
return
}
defer func() {
go func() {
if user.ID == 0 {
log.Error("用户信息为空", errors.New("scs get user is empty"))
return
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
al := systemmodel.NewAuditLog(r, user.Email, user.OS, user.Browser, start, time.Now())
if err := auditLogService.Create(ctx, al); err != nil {
log.Error(err.Error(), err,
zap.Int32("user_id", user.ID),
zap.String("user", user.Email),
zap.String("ip", al.Ip),
zap.String("os", al.Os),
zap.String("method", al.Method),
zap.String("path", al.Url),
)
}
}()
}()
next.ServeHTTP(w, r)
})
}
}

View File

@@ -1,60 +0,0 @@
package middleware
import (
"net/http"
"management/internal/erpserver/model/dto"
v1 "management/internal/erpserver/service/v1"
"management/internal/pkg/know"
"management/internal/pkg/session"
)
var publicRoutes = map[string]bool{
"/home.html": true,
"/dashboard": true,
"/system/menus": true,
"/upload/img": true,
"/upload/file": true,
"/upload/multi_files": true,
"/pear.json": true,
"/logout": true,
}
func Authorize(
sess session.Manager,
menuService v1.MenuService,
) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
path := r.URL.Path
// 登陆检查
user, err := sess.GetUser(ctx, know.StoreName)
if err != nil || user.ID == 0 {
http.Redirect(w, r, "/", http.StatusFound)
return
}
// 公共路由放行
if publicRoutes[path] {
next.ServeHTTP(w, r)
return
}
// 权限检查
menus, err := menuService.ListByRoleIDToMap(ctx, user.RoleID)
if err != nil || !hasPermission(menus, path) {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
}
func hasPermission(menus map[string]*dto.OwnerMenuDto, path string) bool {
_, ok := menus[path]
return ok
}

View File

@@ -1,11 +0,0 @@
package middleware
import (
"net/http"
"github.com/justinas/nosurf"
)
func NoSurf(next http.Handler) http.Handler {
return nosurf.New(next)
}

View File

@@ -61,7 +61,7 @@ func (r *render) setDefaultData(req *http.Request, data map[string]any) map[stri
ctx := req.Context()
authUser, err := r.session.GetUser(ctx, know.StoreName)
if err != nil || authUser == nil {
if err != nil || authUser.ID == 0 {
data["IsAuthenticated"] = false
} else {
data["IsAuthenticated"] = true

View File

@@ -1,83 +1,86 @@
package session
// import (
// "context"
// "time"
import (
"context"
"errors"
"time"
// "management/internal/pkg/redis"
// )
"github.com/redis/go-redis/v9"
)
// var (
// storePrefix = "scs:session:"
// ctx = context.Background()
// DefaultRedisStore = newRedisStore()
// )
// 为所有 Redis 操作定义一个合理的超时时间。
// 这个值应该根据你的服务SLA和网络状况来定通常在50-500毫秒之间是比较合理的。
// 最佳实践:这个值应该来自配置文件,而不是硬编码。
const redisTimeout = 200 * time.Millisecond
// type redisStore struct{}
// RedisStore 表示一个使用 go-redis/v9 客户端的 scs.Store 实现。
type RedisStore struct {
// 内嵌 go-redis 客户端
client *redis.Client
}
// func newRedisStore() *redisStore {
// return &redisStore{}
// }
// NewRedisStore 是 RedisStore 的构造函数。
func NewRedisStore(client *redis.Client) *RedisStore {
return &RedisStore{
client: client,
}
}
// // Delete should remove the session token and corresponding data from the
// // session store. If the token does not exist then Delete should be a no-op
// // and return nil (not an error).
// func (s *redisStore) Delete(token string) error {
// return redis.Del(ctx, storePrefix+token)
// }
// Find 方法根据 session token 从 Redis 中查找 session 数据。
// 如果 token 不存在或已过期exists 返回 false。
func (s *RedisStore) Find(token string) ([]byte, bool, error) {
// ✅ 最佳实践: 为数据库操作创建带超时的上下文
ctx, cancel := context.WithTimeout(context.Background(), redisTimeout)
// ✅ 必须: 无论函数如何返回,都调用 cancel() 来释放上下文资源
defer cancel()
// // Find should return the data for a session token from the store. If the
// // session token is not found or is expired, the found return value should
// // be false (and the err return value should be nil). Similarly, tampered
// // or malformed tokens should result in a found return value of false and a
// // nil err value. The err return value should be used for system errors only.
// func (s *redisStore) Find(token string) (b []byte, found bool, err error) {
// val, err := redis.GetBytes(ctx, storePrefix+token)
// if err != nil {
// return nil, false, err
// } else {
// return val, true, nil
// }
// }
// 使用 go-redis 的 Get 方法
data, err := s.client.Get(ctx, token).Bytes()
if err != nil {
// 如果 key 不存在go-redis 会返回 redis.Nil 错误
if errors.Is(err, redis.Nil) {
return nil, false, nil
}
return nil, false, err
}
// // Commit should add the session token and data to the store, with the given
// // expiry time. If the session token already exists, then the data and
// // expiry time should be overwritten.
// func (s *redisStore) Commit(token string, b []byte, expiry time.Time) error {
// // TODO: 这边可以调整时间
// exp, err := time.ParseInLocation(time.DateTime, time.Now().Format("2006-01-02")+" 23:59:59", time.Local)
// if err != nil {
// return err
// }
return data, true, nil
}
// t := time.Now()
// expired := exp.Sub(t)
// return redis.Set(ctx, storePrefix+token, b, expired)
// }
// Commit 方法将 session 数据和过期时间存入 Redis。
// 如果 token 已存在,则更新其数据和过期时间。
func (s *RedisStore) Commit(token string, b []byte, expiry time.Time) error {
// ✅ 最佳实践: 为数据库操作创建带超时的上下文
ctx, cancel := context.WithTimeout(context.Background(), redisTimeout)
// ✅ 必须: 无论函数如何返回,都调用 cancel() 来释放上下文资源
defer cancel()
// // All should return a map containing data for all active sessions (i.e.
// // sessions which have not expired). The map key should be the session
// // token and the map value should be the session data. If no active
// // sessions exist this should return an empty (not nil) map.
// func (s *redisStore) All() (map[string][]byte, error) {
// sessions := make(map[string][]byte)
// 计算 Redis 的 TTL (Time To Live)
// time.Until(expiry) 会计算出当前时间到 expiry 之间的时间差
ttl := time.Until(expiry)
// iter := redis.Scan(ctx, 0, storePrefix+"*", 0).Iterator()
// for iter.Next(ctx) {
// key := iter.Val()
// token := key[len(storePrefix):]
// data, exists, err := s.Find(token)
// if err != nil {
// return nil, err
// }
// 使用 go-redis 的 Set 方法,并设置过期时间
// 如果 expiry 时间已经过去ttl 会是负数Redis 会立即删除这个 key这正是我们期望的行为。
err := s.client.Set(ctx, token, b, ttl).Err()
if err != nil {
return err
}
// if exists {
// sessions[token] = data
// }
// }
// if err := iter.Err(); err != nil {
// return nil, err
// }
return nil
}
// return sessions, nil
// }
// Delete 方法根据 session token 从 Redis 中删除 session 数据。
func (s *RedisStore) Delete(token string) error {
// ✅ 最佳实践: 为数据库操作创建带超时的上下文
ctx, cancel := context.WithTimeout(context.Background(), redisTimeout)
// ✅ 必须: 无论函数如何返回,都调用 cancel() 来释放上下文资源
defer cancel()
// 使用 go-redis 的 Del 方法
err := s.client.Del(ctx, token).Err()
if err != nil {
return err
}
return nil
}

View File

@@ -10,9 +10,8 @@ import (
"management/internal/erpserver/model/dto"
"management/internal/pkg/config"
"github.com/alexedwards/scs/postgresstore"
"github.com/alexedwards/scs/v2"
"gorm.io/gorm"
"github.com/redis/go-redis/v9"
)
var ErrNoSession = errors.New("session user not found")
@@ -20,8 +19,8 @@ var ErrNoSession = errors.New("session user not found")
// Manager 抽象核心会话操作
type Manager interface {
Load(next http.Handler) http.Handler
GetUser(ctx context.Context, key string) (*dto.AuthorizeUser, error)
PutUser(ctx context.Context, key string, user *dto.AuthorizeUser) error
GetUser(ctx context.Context, key string) (dto.AuthorizeUser, error)
PutUser(ctx context.Context, key string, user dto.AuthorizeUser) error
RenewToken(ctx context.Context) error
Destroy(ctx context.Context) error
}
@@ -30,7 +29,7 @@ type SCSSession struct {
manager *scs.SessionManager
}
func NewSCSManager(db *gorm.DB, config *config.Config) (Manager, error) {
func NewSCSManager(client *redis.Client, conf *config.Config) (Manager, error) {
sessionManager := scs.New()
sessionManager.Lifetime = 24 * time.Hour
sessionManager.IdleTimeout = 2 * time.Hour
@@ -38,21 +37,21 @@ func NewSCSManager(db *gorm.DB, config *config.Config) (Manager, error) {
sessionManager.Cookie.HttpOnly = true
sessionManager.Cookie.Persist = true
sessionManager.Cookie.SameSite = http.SameSiteStrictMode
sessionManager.Cookie.Secure = config.App.Prod
sessionManager.Cookie.Secure = conf.App.Prod
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
//sqlDB, err := db.DB()
//if err != nil {
// return nil, err
//}
// postgres
// github.com/alexedwards/scs/postgresstore
sessionManager.Store = postgresstore.New(sqlDB)
// sessionManager.Store = postgresstore.New(sqlDB)
// pgx
// github.com/alexedwards/scs/pgxstore
// sessionManager.Store = pgxstore.New(pool)
// redis
// sessionManager.Store = newRedisStore()
sessionManager.Store = NewRedisStore(client)
return &SCSSession{manager: sessionManager}, nil
}
@@ -60,21 +59,21 @@ func (s *SCSSession) Load(next http.Handler) http.Handler {
return s.manager.LoadAndSave(next)
}
func (s *SCSSession) GetUser(ctx context.Context, key string) (*dto.AuthorizeUser, error) {
func (s *SCSSession) GetUser(ctx context.Context, key string) (dto.AuthorizeUser, error) {
data, ok := s.manager.Get(ctx, key).([]byte)
if !ok || len(data) == 0 {
return nil, ErrNoSession
return dto.AuthorizeUser{}, ErrNoSession
}
var user dto.AuthorizeUser
if err := json.Unmarshal(data, &user); err != nil {
return nil, err
return dto.AuthorizeUser{}, err
}
return &user, nil
return user, nil
}
func (s *SCSSession) PutUser(ctx context.Context, key string, user *dto.AuthorizeUser) error {
data, err := json.Marshal(user)
func (s *SCSSession) PutUser(ctx context.Context, key string, user dto.AuthorizeUser) error {
data, err := json.Marshal(&user)
if err != nil {
return err
}