package mid import ( "context" "errors" "fmt" "log" "net/http" "strings" "sync" "time" "management/internal/erpserver/model/dto" v1 "management/internal/erpserver/service/v1" "management/internal/pkg/token" "github.com/gin-gonic/gin" "github.com/patrickmn/go-cache" "golang.org/x/sync/singleflight" ) var publicRoutes = map[string]bool{ "/dashboard": true, "/system/menus": true, "/upload/img": true, "/upload/file": true, "/upload/multi_files": true, "/logout": true, "/user_info": true, } // 分片缓存配置 const cacheShards = 64 // 根据CPU核心数调整(建议核心数*4) var ( menuCacheShards [cacheShards]*cache.Cache shardMutexes [cacheShards]*sync.Mutex flightGroup singleflight.Group ) func init() { // 初始化分片缓存 for i := 0; i < cacheShards; i++ { menuCacheShards[i] = cache.New(5*time.Minute, 10*time.Minute) shardMutexes[i] = &sync.Mutex{} } } // 获取分片索引(基于角色ID) func getShardIndex(roleID uint) int { return int(roleID) % cacheShards } // 缓存过期时间分级 func getCacheExpiration(roleID uint) time.Duration { // 这里可以添加业务逻辑区分高频/低频角色 // 示例:高频角色缓存30分钟,低频10分钟 if isHighFrequencyRole(roleID) { return 30 * time.Minute } return 10 * time.Minute } // 判断是否为高频角色(示例实现) func isHighFrequencyRole(roleID uint) bool { // 实际项目中,这里可以根据角色ID查询配置或历史访问频率 return roleID < 100 // 假设ID小于100的角色是高频角色 } // WarmUpMenuCache 缓存预热函数(在服务启动时调用) func WarmUpMenuCache(menuService v1.MenuService) { ctx := context.Background() // 获取所有角色ID(这里需要实现getAllRoleIDs) roleIDs := getAllRoleIDs() log.Printf("Starting cache warm-up for %d roles", len(roleIDs)) // 并发预热 var wg sync.WaitGroup wg.Add(len(roleIDs)) for _, roleID := range roleIDs { go func(rid uint) { defer wg.Done() shardIndex := getShardIndex(rid) cacheKey := fmt.Sprintf("user_menus:%d", rid) shard := menuCacheShards[shardIndex] // 预热数据 if _, found := shard.Get(cacheKey); !found { menus, err := menuService.ListByRoleIDToMap(ctx, int32(rid)) if err == nil { shard.Set(cacheKey, menus, getCacheExpiration(rid)) } } }(roleID) } wg.Wait() log.Println("Menu cache warm-up completed") } // 获取所有角色ID(需要实现) func getAllRoleIDs() []uint { // 实际项目中需要从数据库获取所有角色ID // 这里返回一个示例ID列表 return []uint{1, 2, 3, 4, 5} } const ( authorizationHeaderKey = "authorization" authorizationTypeBearer = "bearer" authorizationPayloadKey = "authorization_payload" ) func Authorize( tokenMaker token.Maker, menuService v1.MenuService, ) gin.HandlerFunc { return func(ctx *gin.Context) { authorizationHeader := ctx.GetHeader(authorizationHeaderKey) if len(authorizationHeader) == 0 { err := errors.New("authorization header is not provided") ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) return } fields := strings.Fields(authorizationHeader) if len(fields) < 2 { err := errors.New("invalid authorization header format") ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) return } authorizationType := strings.ToLower(fields[0]) if authorizationType != authorizationTypeBearer { err := fmt.Errorf("unsupported authorization type %s", authorizationType) ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) return } accessToken := fields[1] payload, err := tokenMaker.VerifyToken(accessToken, token.TypeAccessToken) if err != nil { ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) return } setUser(ctx, payload) ctx.Next() return // 用户校验完毕, 开始校验权限 path := ctx.Request.URL.Path // 公共路由放行 if publicRoutes[path] { setUser(ctx, payload) ctx.Next() return } // 权限检查 var menus map[string]*dto.OwnerMenuDto cacheKey := fmt.Sprintf("user_menus:%d", payload.RoleID) shardIndex := getShardIndex(uint(payload.RoleID)) shard := menuCacheShards[shardIndex] mutex := shardMutexes[shardIndex] // 1. 尝试无锁读缓存 if cachedMenus, found := shard.Get(cacheKey); found { menus = cachedMenus.(map[string]*dto.OwnerMenuDto) } else { // 2. 单飞机制防止缓存击穿 menusI, err, _ := flightGroup.Do(cacheKey, func() (interface{}, error) { // 3. 双检锁机制 if cached, found := shard.Get(cacheKey); found { return cached, nil } // 4. 查询数据库获取菜单数据 maps, err := menuService.ListByRoleIDToMap(ctx, int32(payload.RoleID)) if err != nil { return nil, err } // 5. 写入缓存(加锁避免重复写入) mutex.Lock() defer mutex.Unlock() // 6. 再次检查(避免其他goroutine已经写入) if cached, found := shard.Get(cacheKey); found { return cached, nil } // 7. 设置分级过期时间 expiration := getCacheExpiration(uint(payload.RoleID)) shard.Set(cacheKey, maps, expiration) return maps, nil }) if err != nil { ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) return } menus = menusI.(map[string]*dto.OwnerMenuDto) } if !hasPermission(menus, path) { ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) return } setUser(ctx, payload) ctx.Next() } } func hasPermission(menus map[string]*dto.OwnerMenuDto, path string) bool { _, ok := menus[path] return ok } func errorResponse(err error) gin.H { return gin.H{"error": err.Error()} }