247 lines
6.2 KiB
Go
247 lines
6.2 KiB
Go
package mid
|
||
|
||
import (
	"context"
	"fmt"
	"log"
	"net/http"
	"sort"
	"sync"
	"time"

	"management/internal/erpserver/model/dto"
	v1 "management/internal/erpserver/service/v1"
	"management/internal/pkg/know"
	"management/internal/pkg/session"

	"github.com/patrickmn/go-cache"
	"golang.org/x/sync/singleflight"
)
|
||
|
||
//// 高性能JSON库(全局初始化)
|
||
//var json = jsoniter.ConfigFastest
|
||
//
|
||
//// 使用jsoniter优化菜单结构体序列化
|
||
//func init() {
|
||
// jsoniter.RegisterTypeEncoderFunc("dto.OwnerMenuDto",
|
||
// func(ptr unsafe.Pointer, stream *jsoniter.Stream) {
|
||
// m := (*dto.OwnerMenuDto)(ptr)
|
||
// stream.WriteObjectStart()
|
||
// stream.WriteObjectField("id")
|
||
// stream.WriteUint(uint(m.ID))
|
||
// stream.WriteMore()
|
||
// stream.WriteObjectField("url")
|
||
// stream.WriteString(m.Url)
|
||
// stream.WriteMore()
|
||
// stream.WriteObjectField("parentId")
|
||
// stream.WriteUint(uint(m.ParentID))
|
||
// stream.WriteMore()
|
||
// stream.WriteObjectField("isList")
|
||
// stream.WriteBool(m.IsList)
|
||
// stream.WriteObjectEnd()
|
||
// }, nil)
|
||
//}
|
||
|
||
// publicRoutes is the set of request paths that Authorize forwards without a
// per-role menu permission check (the login check still applies). Looking up
// a path that is not listed yields false.
var publicRoutes = map[string]bool{
	"/home.html": true,
	"/dashboard": true,
	"/logout":    true,

	// Paths under /system.
	"/system/menus":     true,
	"/system/pear.json": true,

	// Paths under /upload.
	"/upload/img":         true,
	"/upload/file":        true,
	"/upload/multi_files": true,
}
|
||
|
||
// 分片缓存配置
|
||
const cacheShards = 64 // 根据CPU核心数调整(建议核心数*4)
|
||
var (
|
||
menuCacheShards [cacheShards]*cache.Cache
|
||
shardMutexes [cacheShards]*sync.Mutex
|
||
flightGroup singleflight.Group
|
||
)
|
||
|
||
func init() {
|
||
// 初始化分片缓存
|
||
for i := 0; i < cacheShards; i++ {
|
||
menuCacheShards[i] = cache.New(5*time.Minute, 10*time.Minute)
|
||
shardMutexes[i] = &sync.Mutex{}
|
||
}
|
||
}
|
||
|
||
// 获取分片索引(基于角色ID)
|
||
func getShardIndex(roleID uint) int {
|
||
return int(roleID) % cacheShards
|
||
}
|
||
|
||
// 缓存过期时间分级
|
||
func getCacheExpiration(roleID uint) time.Duration {
|
||
// 这里可以添加业务逻辑区分高频/低频角色
|
||
// 示例:高频角色缓存30分钟,低频10分钟
|
||
if isHighFrequencyRole(roleID) {
|
||
return 30 * time.Minute
|
||
}
|
||
return 10 * time.Minute
|
||
}
|
||
|
||
// isHighFrequencyRole reports whether a role is treated as high-frequency.
// This is a sample rule: role IDs below 100 are assumed to be high-frequency.
// A real project could instead look the role up in configuration or in
// historical access statistics.
func isHighFrequencyRole(roleID uint) bool {
	const highFrequencyIDLimit = 100
	return roleID < highFrequencyIDLimit
}
|
||
|
||
// WarmUpMenuCache 缓存预热函数(在服务启动时调用)
|
||
func WarmUpMenuCache(menuService v1.MenuService) {
|
||
ctx := context.Background()
|
||
// 获取所有角色ID(这里需要实现getAllRoleIDs)
|
||
roleIDs := getAllRoleIDs()
|
||
|
||
log.Printf("Starting cache warm-up for %d roles", len(roleIDs))
|
||
|
||
// 并发预热
|
||
var wg sync.WaitGroup
|
||
wg.Add(len(roleIDs))
|
||
|
||
for _, roleID := range roleIDs {
|
||
go func(rid uint) {
|
||
defer wg.Done()
|
||
|
||
shardIndex := getShardIndex(rid)
|
||
cacheKey := fmt.Sprintf("user_menus:%d", rid)
|
||
shard := menuCacheShards[shardIndex]
|
||
|
||
// 预热数据
|
||
if _, found := shard.Get(cacheKey); !found {
|
||
menus, err := menuService.ListByRoleIDToMap(ctx, int32(rid))
|
||
if err == nil {
|
||
shard.Set(cacheKey, menus, getCacheExpiration(rid))
|
||
}
|
||
}
|
||
}(roleID)
|
||
}
|
||
|
||
wg.Wait()
|
||
log.Println("Menu cache warm-up completed")
|
||
}
|
||
|
||
// getAllRoleIDs returns the role IDs whose menus should be preloaded.
// Placeholder implementation: a real project would read the IDs from the
// database; for now it yields the fixed sample list 1 through 5.
func getAllRoleIDs() []uint {
	const sampleCount = 5

	ids := make([]uint, 0, sampleCount)
	for id := uint(1); id <= sampleCount; id++ {
		ids = append(ids, id)
	}
	return ids
}
|
||
|
||
func Authorize(
|
||
sess session.Manager,
|
||
menuService v1.MenuService,
|
||
) func(http.Handler) http.Handler {
|
||
return func(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
ctx := r.Context()
|
||
path := r.URL.Path
|
||
|
||
// 登陆检查
|
||
user, err := sess.GetUser(ctx, know.StoreName)
|
||
if err != nil || user.ID == 0 {
|
||
http.Redirect(w, r, "/", http.StatusFound)
|
||
return
|
||
}
|
||
|
||
// 公共路由放行
|
||
if publicRoutes[path] {
|
||
ctx = setUser(ctx, user)
|
||
next.ServeHTTP(w, r.WithContext(ctx))
|
||
return
|
||
}
|
||
|
||
n1 := time.Now()
|
||
// 权限检查
|
||
var menus map[string]*dto.OwnerMenuDto
|
||
cacheKey := fmt.Sprintf("user_menus:%d", user.RoleID)
|
||
shardIndex := getShardIndex(uint(user.RoleID))
|
||
shard := menuCacheShards[shardIndex]
|
||
mutex := shardMutexes[shardIndex]
|
||
|
||
// 1. 尝试无锁读缓存
|
||
if cachedMenus, found := shard.Get(cacheKey); found {
|
||
menus = cachedMenus.(map[string]*dto.OwnerMenuDto)
|
||
log.Printf("listByRoleIDToMap (from cache): %s", time.Since(n1).String())
|
||
} else {
|
||
// 2. 单飞机制防止缓存击穿
|
||
menusI, err, _ := flightGroup.Do(cacheKey, func() (interface{}, error) {
|
||
// 3. 双检锁机制
|
||
if cached, found := shard.Get(cacheKey); found {
|
||
return cached, nil
|
||
}
|
||
|
||
// 4. 查询数据库获取菜单数据
|
||
maps, err := menuService.ListByRoleIDToMap(ctx, user.RoleID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 5. 写入缓存(加锁避免重复写入)
|
||
mutex.Lock()
|
||
defer mutex.Unlock()
|
||
|
||
// 6. 再次检查(避免其他goroutine已经写入)
|
||
if cached, found := shard.Get(cacheKey); found {
|
||
return cached, nil
|
||
}
|
||
|
||
// 7. 设置分级过期时间
|
||
expiration := getCacheExpiration(uint(user.RoleID))
|
||
shard.Set(cacheKey, maps, expiration)
|
||
return maps, nil
|
||
})
|
||
|
||
if err != nil {
|
||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||
return
|
||
}
|
||
menus = menusI.(map[string]*dto.OwnerMenuDto)
|
||
log.Printf("listByRoleIDToMap (from DB, then cached): %s", time.Since(n1).String())
|
||
}
|
||
|
||
if !hasPermission(menus, path) {
|
||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||
return
|
||
}
|
||
|
||
cur := getCurrentMenus(menus, path)
|
||
|
||
ctx = setUser(ctx, user)
|
||
ctx = setCurMenus(ctx, cur)
|
||
|
||
next.ServeHTTP(w, r.WithContext(ctx))
|
||
})
|
||
}
|
||
}
|
||
|
||
func hasPermission(menus map[string]*dto.OwnerMenuDto, path string) bool {
|
||
_, ok := menus[path]
|
||
return ok
|
||
}
|
||
|
||
func getCurrentMenus(data map[string]*dto.OwnerMenuDto, path string) []dto.OwnerMenuDto {
|
||
var res []dto.OwnerMenuDto
|
||
|
||
menu, ok := data[path]
|
||
if !ok {
|
||
return res
|
||
}
|
||
|
||
for _, item := range data {
|
||
if menu.IsList {
|
||
if item.ParentID == menu.ID || item.ID == menu.ID {
|
||
res = append(res, *item)
|
||
}
|
||
} else {
|
||
if item.ParentID == menu.ParentID {
|
||
res = append(res, *item)
|
||
}
|
||
}
|
||
}
|
||
|
||
return res
|
||
}
|