230 lines
5.7 KiB
Go
230 lines
5.7 KiB
Go
package mid
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"log"
|
||
"net/http"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"management/internal/erpserver/model/dto"
|
||
v1 "management/internal/erpserver/service/v1"
|
||
"management/internal/pkg/token"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/patrickmn/go-cache"
|
||
"golang.org/x/sync/singleflight"
|
||
)
|
||
|
||
// publicRoutes lists request paths that are exempt from the per-role menu
// permission check performed in Authorize. Callers on these paths still need
// a valid access token. Entries are kept in alphabetical order.
var publicRoutes = map[string]bool{
	"/dashboard":          true,
	"/logout":             true,
	"/system/menus":       true,
	"/upload/file":        true,
	"/upload/img":         true,
	"/upload/multi_files": true,
	"/user_info":          true,
}
|
||
|
||
// 分片缓存配置
|
||
const cacheShards = 64 // 根据CPU核心数调整(建议核心数*4)
|
||
var (
|
||
menuCacheShards [cacheShards]*cache.Cache
|
||
shardMutexes [cacheShards]*sync.Mutex
|
||
flightGroup singleflight.Group
|
||
)
|
||
|
||
func init() {
|
||
// 初始化分片缓存
|
||
for i := 0; i < cacheShards; i++ {
|
||
menuCacheShards[i] = cache.New(5*time.Minute, 10*time.Minute)
|
||
shardMutexes[i] = &sync.Mutex{}
|
||
}
|
||
}
|
||
|
||
// 获取分片索引(基于角色ID)
|
||
func getShardIndex(roleID uint) int {
|
||
return int(roleID) % cacheShards
|
||
}
|
||
|
||
// 缓存过期时间分级
|
||
func getCacheExpiration(roleID uint) time.Duration {
|
||
// 这里可以添加业务逻辑区分高频/低频角色
|
||
// 示例:高频角色缓存30分钟,低频10分钟
|
||
if isHighFrequencyRole(roleID) {
|
||
return 30 * time.Minute
|
||
}
|
||
return 10 * time.Minute
|
||
}
|
||
|
||
// isHighFrequencyRole reports whether roleID should be treated as a
// frequently accessed role. This is a placeholder rule in which IDs below 100
// count as high-frequency. A real implementation would consult configuration
// or historical access statistics.
func isHighFrequencyRole(roleID uint) bool {
	const highFrequencyIDLimit uint = 100
	return highFrequencyIDLimit > roleID
}
|
||
|
||
// WarmUpMenuCache 缓存预热函数(在服务启动时调用)
|
||
func WarmUpMenuCache(menuService v1.MenuService) {
|
||
ctx := context.Background()
|
||
// 获取所有角色ID(这里需要实现getAllRoleIDs)
|
||
roleIDs := getAllRoleIDs()
|
||
|
||
log.Printf("Starting cache warm-up for %d roles", len(roleIDs))
|
||
|
||
// 并发预热
|
||
var wg sync.WaitGroup
|
||
wg.Add(len(roleIDs))
|
||
|
||
for _, roleID := range roleIDs {
|
||
go func(rid uint) {
|
||
defer wg.Done()
|
||
|
||
shardIndex := getShardIndex(rid)
|
||
cacheKey := fmt.Sprintf("user_menus:%d", rid)
|
||
shard := menuCacheShards[shardIndex]
|
||
|
||
// 预热数据
|
||
if _, found := shard.Get(cacheKey); !found {
|
||
menus, err := menuService.ListByRoleIDToMap(ctx, int32(rid))
|
||
if err == nil {
|
||
shard.Set(cacheKey, menus, getCacheExpiration(rid))
|
||
}
|
||
}
|
||
}(roleID)
|
||
}
|
||
|
||
wg.Wait()
|
||
log.Println("Menu cache warm-up completed")
|
||
}
|
||
|
||
// getAllRoleIDs returns the role IDs that WarmUpMenuCache should preload.
// TODO: this is a placeholder that always returns the fixed IDs 1 through 5.
// Replace it with a database query over all roles.
func getAllRoleIDs() []uint {
	const placeholderRoleCount = 5
	ids := make([]uint, 0, placeholderRoleCount)
	for id := uint(1); id <= placeholderRoleCount; id++ {
		ids = append(ids, id)
	}
	return ids
}
|
||
|
||
// Names used when reading and propagating bearer-token credentials.
const (
	// authorizationHeaderKey is the request header that carries credentials.
	authorizationHeaderKey = "authorization"

	// authorizationPayloadKey is a key name for the verified token payload.
	// NOTE(review): this name is not referenced in this file. It is
	// presumably used by setUser or by other code in package mid; confirm.
	authorizationPayloadKey = "authorization_payload"

	// authorizationTypeBearer is the only accepted scheme. It is compared
	// against the lower-cased first field of the header.
	authorizationTypeBearer = "bearer"
)
|
||
|
||
func Authorize(
|
||
tokenMaker token.Maker,
|
||
menuService v1.MenuService,
|
||
) gin.HandlerFunc {
|
||
return func(ctx *gin.Context) {
|
||
|
||
authorizationHeader := ctx.GetHeader(authorizationHeaderKey)
|
||
|
||
if len(authorizationHeader) == 0 {
|
||
err := errors.New("authorization header is not provided")
|
||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||
return
|
||
}
|
||
|
||
fields := strings.Fields(authorizationHeader)
|
||
if len(fields) < 2 {
|
||
err := errors.New("invalid authorization header format")
|
||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||
return
|
||
}
|
||
|
||
authorizationType := strings.ToLower(fields[0])
|
||
if authorizationType != authorizationTypeBearer {
|
||
err := fmt.Errorf("unsupported authorization type %s", authorizationType)
|
||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||
return
|
||
}
|
||
|
||
accessToken := fields[1]
|
||
payload, err := tokenMaker.VerifyToken(accessToken, token.TypeAccessToken)
|
||
if err != nil {
|
||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||
return
|
||
}
|
||
|
||
setUser(ctx, payload)
|
||
ctx.Next()
|
||
return
|
||
|
||
// 用户校验完毕, 开始校验权限
|
||
|
||
path := ctx.Request.URL.Path
|
||
|
||
// 公共路由放行
|
||
if publicRoutes[path] {
|
||
setUser(ctx, payload)
|
||
ctx.Next()
|
||
return
|
||
}
|
||
|
||
// 权限检查
|
||
var menus map[string]*dto.OwnerMenuDto
|
||
cacheKey := fmt.Sprintf("user_menus:%d", payload.RoleID)
|
||
shardIndex := getShardIndex(uint(payload.RoleID))
|
||
shard := menuCacheShards[shardIndex]
|
||
mutex := shardMutexes[shardIndex]
|
||
|
||
// 1. 尝试无锁读缓存
|
||
if cachedMenus, found := shard.Get(cacheKey); found {
|
||
menus = cachedMenus.(map[string]*dto.OwnerMenuDto)
|
||
} else {
|
||
// 2. 单飞机制防止缓存击穿
|
||
menusI, err, _ := flightGroup.Do(cacheKey, func() (interface{}, error) {
|
||
// 3. 双检锁机制
|
||
if cached, found := shard.Get(cacheKey); found {
|
||
return cached, nil
|
||
}
|
||
|
||
// 4. 查询数据库获取菜单数据
|
||
maps, err := menuService.ListByRoleIDToMap(ctx, int32(payload.RoleID))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 5. 写入缓存(加锁避免重复写入)
|
||
mutex.Lock()
|
||
defer mutex.Unlock()
|
||
|
||
// 6. 再次检查(避免其他goroutine已经写入)
|
||
if cached, found := shard.Get(cacheKey); found {
|
||
return cached, nil
|
||
}
|
||
|
||
// 7. 设置分级过期时间
|
||
expiration := getCacheExpiration(uint(payload.RoleID))
|
||
shard.Set(cacheKey, maps, expiration)
|
||
return maps, nil
|
||
})
|
||
|
||
if err != nil {
|
||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||
return
|
||
}
|
||
menus = menusI.(map[string]*dto.OwnerMenuDto)
|
||
}
|
||
|
||
if !hasPermission(menus, path) {
|
||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||
return
|
||
}
|
||
|
||
setUser(ctx, payload)
|
||
ctx.Next()
|
||
}
|
||
}
|
||
|
||
func hasPermission(menus map[string]*dto.OwnerMenuDto, path string) bool {
|
||
_, ok := menus[path]
|
||
return ok
|
||
}
|
||
|
||
func errorResponse(err error) gin.H {
|
||
return gin.H{"error": err.Error()}
|
||
}
|