v1
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package convertor
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
@@ -23,13 +22,8 @@ func Bool(value string, defaultValue bool) bool {
|
||||
return b
|
||||
}
|
||||
|
||||
func QueryInt[T int | int16 | int32 | int64](vars url.Values, key string, defaultValue T) T {
|
||||
v := vars.Get(key)
|
||||
if len(v) == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
i, err := strconv.Atoi(v)
|
||||
func QueryInt[T int | int16 | int32 | int64](value string, defaultValue T) T {
|
||||
i, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
26
internal/pkg/gin/gu/cors.go
Normal file
26
internal/pkg/gin/gu/cors.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package gu
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Cors() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
method := c.Request.Method
|
||||
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type,AccessToken,X-CSRF-Token, Authorization, Token")
|
||||
c.Header("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
|
||||
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
//放行所有OPTIONS方法
|
||||
if method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
}
|
||||
// 处理请求
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
49
internal/pkg/gin/gu/response.go
Normal file
49
internal/pkg/gin/gu/response.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package gu
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data"`
|
||||
}
|
||||
|
||||
type PageData struct {
|
||||
Total int64 `json:"total"`
|
||||
PageID int `json:"page_id"`
|
||||
PageSize int `json:"page_size"`
|
||||
Result any `json:"result"`
|
||||
}
|
||||
|
||||
func NewPageData(total int64, pageID, pageSize int, result any) PageData {
|
||||
return PageData{
|
||||
Total: total,
|
||||
PageID: pageID,
|
||||
PageSize: pageSize,
|
||||
Result: result,
|
||||
}
|
||||
}
|
||||
|
||||
func Ok(ctx *gin.Context, data any) {
|
||||
ResponseJson(ctx, http.StatusOK, "ok", data)
|
||||
}
|
||||
|
||||
func Failed(ctx *gin.Context, message string) {
|
||||
ResponseJson(ctx, http.StatusInternalServerError, message, nil)
|
||||
}
|
||||
|
||||
func FailedWithCode(ctx *gin.Context, code int, message string) {
|
||||
ResponseJson(ctx, code, message, nil)
|
||||
}
|
||||
|
||||
func ResponseJson(ctx *gin.Context, code int, message string, data any) {
|
||||
ctx.JSON(code, response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
58
internal/pkg/gin/gu/validator.go
Normal file
58
internal/pkg/gin/gu/validator.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package gu
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/go-playground/locales/en"
|
||||
"github.com/go-playground/locales/zh"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
"github.com/go-playground/validator/v10"
|
||||
enTranslations "github.com/go-playground/validator/v10/translations/en"
|
||||
chTranslations "github.com/go-playground/validator/v10/translations/zh"
|
||||
)
|
||||
|
||||
var trans ut.Translator
|
||||
|
||||
// SetValidatorTrans
|
||||
// local 通常取决于 http 请求头的 'Accept-Language'
|
||||
func SetValidatorTrans(local string) (err error) {
|
||||
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
||||
zhT := zh.New() //chinese
|
||||
enT := en.New() //english
|
||||
uni := ut.New(enT, zhT, enT)
|
||||
|
||||
var o bool
|
||||
trans, o = uni.GetTranslator(local)
|
||||
if !o {
|
||||
return fmt.Errorf("uni.GetTranslator(%s) failed", local)
|
||||
}
|
||||
//register translate
|
||||
// 注册翻译器
|
||||
switch local {
|
||||
case "en":
|
||||
err = enTranslations.RegisterDefaultTranslations(v, trans)
|
||||
case "zh":
|
||||
err = chTranslations.RegisterDefaultTranslations(v, trans)
|
||||
default:
|
||||
err = enTranslations.RegisterDefaultTranslations(v, trans)
|
||||
}
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ValidatorErrors(ctx *gin.Context, err error) {
|
||||
var errs validator.ValidationErrors
|
||||
if errors.As(err, &errs) {
|
||||
es := gin.H{}
|
||||
for _, e := range errs {
|
||||
es[e.StructField()] = strings.Replace(e.Translate(trans), e.StructField(), "", -1)
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, es)
|
||||
}
|
||||
}
|
||||
@@ -2,56 +2,51 @@ package mid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"management/internal/erpserver/model/system"
|
||||
"management/internal/pkg/know"
|
||||
"management/internal/pkg/session"
|
||||
"management/internal/tasks"
|
||||
|
||||
"github.com/drhin/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Audit 改造后的中间件
|
||||
func Audit(sess session.Manager, log *logger.Logger, task tasks.TaskDistributor) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
start := time.Now()
|
||||
func Audit(log *logger.Logger, task tasks.TaskDistributor) gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
start := time.Now()
|
||||
log.Info(start.Format(time.DateTime))
|
||||
ctx.Next()
|
||||
|
||||
user, err := sess.GetUser(ctx, know.StoreName)
|
||||
if err != nil {
|
||||
log.Error("获取用户会话失败", err)
|
||||
}
|
||||
payload := GetUser(ctx)
|
||||
if payload == nil {
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
if payload.ID == uuid.Nil {
|
||||
return
|
||||
}
|
||||
|
||||
if user.ID == 0 {
|
||||
return
|
||||
}
|
||||
logs := &tasks.PayloadConsumeAuditLog{
|
||||
//AuditLog: system.NewAuditLog(r, payload.Email, payload.OS, payload.Browser, start, time.Now()),
|
||||
}
|
||||
|
||||
payload := &tasks.PayloadConsumeAuditLog{
|
||||
AuditLog: system.NewAuditLog(r, user.Email, user.OS, user.Browser, start, time.Now()),
|
||||
}
|
||||
opts := []asynq.Option{
|
||||
asynq.MaxRetry(10),
|
||||
asynq.ProcessIn(1 * time.Second),
|
||||
asynq.Queue(tasks.QueueCritical),
|
||||
}
|
||||
|
||||
opts := []asynq.Option{
|
||||
asynq.MaxRetry(10),
|
||||
asynq.ProcessIn(1 * time.Second),
|
||||
asynq.Queue(tasks.QueueCritical),
|
||||
}
|
||||
c, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
c, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := task.DistributeTaskConsumeAuditLog(c, payload, opts...); err != nil {
|
||||
log.Error("distribute task failed", err,
|
||||
zap.String("type", "audit"),
|
||||
zap.Any("payload", payload),
|
||||
)
|
||||
}
|
||||
})
|
||||
if err := task.DistributeTaskConsumeAuditLog(c, logs, opts...); err != nil {
|
||||
log.Error("distribute task failed", err,
|
||||
zap.String("type", "audit"),
|
||||
zap.Any("payload", payload),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,53 +2,29 @@ package mid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"management/internal/erpserver/model/dto"
|
||||
v1 "management/internal/erpserver/service/v1"
|
||||
"management/internal/pkg/know"
|
||||
"management/internal/pkg/session"
|
||||
"management/internal/pkg/token"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
//// 高性能JSON库(全局初始化)
|
||||
//var json = jsoniter.ConfigFastest
|
||||
//
|
||||
//// 使用jsoniter优化菜单结构体序列化
|
||||
//func init() {
|
||||
// jsoniter.RegisterTypeEncoderFunc("dto.OwnerMenuDto",
|
||||
// func(ptr unsafe.Pointer, stream *jsoniter.Stream) {
|
||||
// m := (*dto.OwnerMenuDto)(ptr)
|
||||
// stream.WriteObjectStart()
|
||||
// stream.WriteObjectField("id")
|
||||
// stream.WriteUint(uint(m.ID))
|
||||
// stream.WriteMore()
|
||||
// stream.WriteObjectField("url")
|
||||
// stream.WriteString(m.Url)
|
||||
// stream.WriteMore()
|
||||
// stream.WriteObjectField("parentId")
|
||||
// stream.WriteUint(uint(m.ParentID))
|
||||
// stream.WriteMore()
|
||||
// stream.WriteObjectField("isList")
|
||||
// stream.WriteBool(m.IsList)
|
||||
// stream.WriteObjectEnd()
|
||||
// }, nil)
|
||||
//}
|
||||
|
||||
var publicRoutes = map[string]bool{
|
||||
"/home.html": true,
|
||||
"/dashboard": true,
|
||||
"/system/menus": true,
|
||||
"/upload/img": true,
|
||||
"/upload/file": true,
|
||||
"/upload/multi_files": true,
|
||||
"/system/pear.json": true,
|
||||
"/logout": true,
|
||||
}
|
||||
|
||||
@@ -130,90 +106,111 @@ func getAllRoleIDs() []uint {
|
||||
return []uint{1, 2, 3, 4, 5}
|
||||
}
|
||||
|
||||
const (
|
||||
authorizationHeaderKey = "authorization"
|
||||
authorizationTypeBearer = "bearer"
|
||||
authorizationPayloadKey = "authorization_payload"
|
||||
)
|
||||
|
||||
func Authorize(
|
||||
sess session.Manager,
|
||||
tokenMaker token.Maker,
|
||||
menuService v1.MenuService,
|
||||
) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
path := r.URL.Path
|
||||
) gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
|
||||
// 登陆检查
|
||||
user, err := sess.GetUser(ctx, know.StoreName)
|
||||
if err != nil || user.ID == 0 {
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
authorizationHeader := ctx.GetHeader(authorizationHeaderKey)
|
||||
|
||||
// 公共路由放行
|
||||
if publicRoutes[path] {
|
||||
ctx = setUser(ctx, user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
if len(authorizationHeader) == 0 {
|
||||
err := errors.New("authorization header is not provided")
|
||||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||||
return
|
||||
}
|
||||
|
||||
n1 := time.Now()
|
||||
// 权限检查
|
||||
var menus map[string]*dto.OwnerMenuDto
|
||||
cacheKey := fmt.Sprintf("user_menus:%d", user.RoleID)
|
||||
shardIndex := getShardIndex(uint(user.RoleID))
|
||||
shard := menuCacheShards[shardIndex]
|
||||
mutex := shardMutexes[shardIndex]
|
||||
fields := strings.Fields(authorizationHeader)
|
||||
if len(fields) < 2 {
|
||||
err := errors.New("invalid authorization header format")
|
||||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 1. 尝试无锁读缓存
|
||||
if cachedMenus, found := shard.Get(cacheKey); found {
|
||||
menus = cachedMenus.(map[string]*dto.OwnerMenuDto)
|
||||
log.Printf("listByRoleIDToMap (from cache): %s", time.Since(n1).String())
|
||||
} else {
|
||||
// 2. 单飞机制防止缓存击穿
|
||||
menusI, err, _ := flightGroup.Do(cacheKey, func() (interface{}, error) {
|
||||
// 3. 双检锁机制
|
||||
if cached, found := shard.Get(cacheKey); found {
|
||||
return cached, nil
|
||||
}
|
||||
authorizationType := strings.ToLower(fields[0])
|
||||
if authorizationType != authorizationTypeBearer {
|
||||
err := fmt.Errorf("unsupported authorization type %s", authorizationType)
|
||||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 查询数据库获取菜单数据
|
||||
maps, err := menuService.ListByRoleIDToMap(ctx, user.RoleID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessToken := fields[1]
|
||||
payload, err := tokenMaker.VerifyToken(accessToken, token.TypeAccessToken)
|
||||
if err != nil {
|
||||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 5. 写入缓存(加锁避免重复写入)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
// 用户校验完毕, 开始校验权限
|
||||
|
||||
// 6. 再次检查(避免其他goroutine已经写入)
|
||||
if cached, found := shard.Get(cacheKey); found {
|
||||
return cached, nil
|
||||
}
|
||||
path := ctx.Request.URL.Path
|
||||
|
||||
// 7. 设置分级过期时间
|
||||
expiration := getCacheExpiration(uint(user.RoleID))
|
||||
shard.Set(cacheKey, maps, expiration)
|
||||
return maps, nil
|
||||
})
|
||||
// 公共路由放行
|
||||
if publicRoutes[path] {
|
||||
ctx.Set(authorizationPayloadKey, payload)
|
||||
ctx.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
// 权限检查
|
||||
var menus map[string]*dto.OwnerMenuDto
|
||||
cacheKey := fmt.Sprintf("user_menus:%d", payload.RoleID)
|
||||
shardIndex := getShardIndex(uint(payload.RoleID))
|
||||
shard := menuCacheShards[shardIndex]
|
||||
mutex := shardMutexes[shardIndex]
|
||||
|
||||
// 1. 尝试无锁读缓存
|
||||
if cachedMenus, found := shard.Get(cacheKey); found {
|
||||
menus = cachedMenus.(map[string]*dto.OwnerMenuDto)
|
||||
} else {
|
||||
// 2. 单飞机制防止缓存击穿
|
||||
menusI, err, _ := flightGroup.Do(cacheKey, func() (interface{}, error) {
|
||||
// 3. 双检锁机制
|
||||
if cached, found := shard.Get(cacheKey); found {
|
||||
return cached, nil
|
||||
}
|
||||
menus = menusI.(map[string]*dto.OwnerMenuDto)
|
||||
log.Printf("listByRoleIDToMap (from DB, then cached): %s", time.Since(n1).String())
|
||||
}
|
||||
|
||||
if !hasPermission(menus, path) {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
// 4. 查询数据库获取菜单数据
|
||||
maps, err := menuService.ListByRoleIDToMap(ctx, int32(payload.RoleID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 5. 写入缓存(加锁避免重复写入)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
// 6. 再次检查(避免其他goroutine已经写入)
|
||||
if cached, found := shard.Get(cacheKey); found {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// 7. 设置分级过期时间
|
||||
expiration := getCacheExpiration(uint(payload.RoleID))
|
||||
shard.Set(cacheKey, maps, expiration)
|
||||
return maps, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||||
return
|
||||
}
|
||||
menus = menusI.(map[string]*dto.OwnerMenuDto)
|
||||
}
|
||||
|
||||
cur := getCurrentMenus(menus, path)
|
||||
if !hasPermission(menus, path) {
|
||||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx = setUser(ctx, user)
|
||||
ctx = setCurMenus(ctx, cur)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
setUser(ctx, payload)
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,25 +219,6 @@ func hasPermission(menus map[string]*dto.OwnerMenuDto, path string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
func getCurrentMenus(data map[string]*dto.OwnerMenuDto, path string) []dto.OwnerMenuDto {
|
||||
var res []dto.OwnerMenuDto
|
||||
|
||||
menu, ok := data[path]
|
||||
if !ok {
|
||||
return res
|
||||
}
|
||||
|
||||
for _, item := range data {
|
||||
if menu.IsList {
|
||||
if item.ParentID == menu.ID || item.ID == menu.ID {
|
||||
res = append(res, *item)
|
||||
}
|
||||
} else {
|
||||
if item.ParentID == menu.ParentID {
|
||||
res = append(res, *item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
func errorResponse(err error) gin.H {
|
||||
return gin.H{"error": err.Error()}
|
||||
}
|
||||
|
||||
@@ -5,26 +5,27 @@ import (
|
||||
"errors"
|
||||
|
||||
"management/internal/erpserver/model/dto"
|
||||
"management/internal/erpserver/model/system"
|
||||
"management/internal/pkg/sqldb"
|
||||
"management/internal/pkg/token"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type userKey struct{}
|
||||
|
||||
func setUser(ctx context.Context, usr system.AuthorizeUser) context.Context {
|
||||
return context.WithValue(ctx, userKey{}, usr)
|
||||
func setUser(ctx *gin.Context, payload *token.Payload) {
|
||||
ctx.Set(authorizationPayloadKey, payload)
|
||||
}
|
||||
|
||||
// GetUser returns the user from the context.
|
||||
func GetUser(ctx context.Context) system.AuthorizeUser {
|
||||
v, ok := ctx.Value(userKey{}).(system.AuthorizeUser)
|
||||
if !ok {
|
||||
return system.AuthorizeUser{}
|
||||
func GetUser(ctx *gin.Context) *token.Payload {
|
||||
value, exists := ctx.Get(authorizationHeaderKey)
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
return v
|
||||
return value.(*token.Payload)
|
||||
}
|
||||
|
||||
type menuKey struct{}
|
||||
|
||||
64
internal/pkg/token/jwt_maker.go
Normal file
64
internal/pkg/token/jwt_maker.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const minSecretKeySize = 32
|
||||
|
||||
// JWTMaker is a JSON Web Token maker
|
||||
type JWTMaker struct {
|
||||
secretKey string
|
||||
}
|
||||
|
||||
// NewJWTMaker creates a new JWTMaker
|
||||
func NewJWTMaker(secretKey string) (Maker, error) {
|
||||
if len(secretKey) < minSecretKeySize {
|
||||
return nil, fmt.Errorf("invalid key size: must be at least %d characters", minSecretKeySize)
|
||||
}
|
||||
return &JWTMaker{secretKey}, nil
|
||||
}
|
||||
|
||||
// CreateToken creates a new token for a specific username and duration
|
||||
func (maker *JWTMaker) CreateToken(uuid uuid.UUID, username string, duration time.Duration, tokenType Type) (string, *Payload, error) {
|
||||
payload := NewPayload(uuid, username, duration, tokenType)
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||||
token, err := jwtToken.SignedString([]byte(maker.secretKey))
|
||||
return token, payload, err
|
||||
}
|
||||
|
||||
// VerifyToken checks if the token is valid or not
|
||||
func (maker *JWTMaker) VerifyToken(token string, tokenType Type) (*Payload, error) {
|
||||
keyFunc := func(token *jwt.Token) (interface{}, error) {
|
||||
_, ok := token.Method.(*jwt.SigningMethodHMAC)
|
||||
if !ok {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
return []byte(maker.secretKey), nil
|
||||
}
|
||||
|
||||
jwtToken, err := jwt.ParseWithClaims(token, &Payload{}, keyFunc)
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrExpiredToken
|
||||
}
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
payload, ok := jwtToken.Claims.(*Payload)
|
||||
if !ok {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
err = payload.Valid(tokenType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
16
internal/pkg/token/maker.go
Normal file
16
internal/pkg/token/maker.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Maker is an interface for managing tokens
|
||||
type Maker interface {
|
||||
// CreateToken creates a new token for a specific username and duration
|
||||
CreateToken(uuid uuid.UUID, username string, duration time.Duration, tokenType Type) (string, *Payload, error)
|
||||
|
||||
// VerifyToken checks if the token is valid or not
|
||||
VerifyToken(token string, tokenType Type) (*Payload, error)
|
||||
}
|
||||
54
internal/pkg/token/paseto_maker.go
Normal file
54
internal/pkg/token/paseto_maker.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aead/chacha20poly1305"
|
||||
"github.com/google/uuid"
|
||||
"github.com/o1egl/paseto"
|
||||
)
|
||||
|
||||
// PasetoMaker is a PASETO token maker
|
||||
type PasetoMaker struct {
|
||||
paseto *paseto.V2
|
||||
symmetricKey []byte
|
||||
}
|
||||
|
||||
// NewPasetoMaker creates a new PasetoMaker
|
||||
func NewPasetoMaker(symmetricKey string) (Maker, error) {
|
||||
if len(symmetricKey) != chacha20poly1305.KeySize {
|
||||
return nil, fmt.Errorf("invalid key size: must be exactly %d characters", chacha20poly1305.KeySize)
|
||||
}
|
||||
|
||||
maker := &PasetoMaker{
|
||||
paseto: paseto.NewV2(),
|
||||
symmetricKey: []byte(symmetricKey),
|
||||
}
|
||||
|
||||
return maker, nil
|
||||
}
|
||||
|
||||
// CreateToken creates a new token for a specific username and duration
|
||||
func (maker *PasetoMaker) CreateToken(uuid uuid.UUID, username string, duration time.Duration, tokenType Type) (string, *Payload, error) {
|
||||
payload := NewPayload(uuid, username, duration, tokenType)
|
||||
token, err := maker.paseto.Encrypt(maker.symmetricKey, payload, nil)
|
||||
return token, payload, err
|
||||
}
|
||||
|
||||
// VerifyToken checks if the token is valid or not
|
||||
func (maker *PasetoMaker) VerifyToken(token string, tokenType Type) (*Payload, error) {
|
||||
payload := &Payload{}
|
||||
|
||||
err := maker.paseto.Decrypt(token, maker.symmetricKey, payload, nil)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
err = payload.Valid(tokenType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
85
internal/pkg/token/payload.go
Normal file
85
internal/pkg/token/payload.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Different types of error returned by the VerifyToken function
|
||||
var (
|
||||
ErrInvalidToken = errors.New("token is invalid")
|
||||
ErrExpiredToken = errors.New("token has expired")
|
||||
)
|
||||
|
||||
type Type byte
|
||||
|
||||
const (
|
||||
TypeAccessToken = 1
|
||||
TypeRefreshToken = 2
|
||||
)
|
||||
|
||||
// Payload contains the payload data of the token
|
||||
type Payload struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
RoleID int `json:"role_id"`
|
||||
Type Type `json:"token_type"`
|
||||
Username string `json:"username"`
|
||||
IssuedAt time.Time `json:"issued_at"`
|
||||
ExpiredAt time.Time `json:"expired_at"`
|
||||
}
|
||||
|
||||
// NewPayload creates a new token payload with a specific username and duration
|
||||
func NewPayload(uuid uuid.UUID, username string, duration time.Duration, tokenType Type) *Payload {
|
||||
payload := &Payload{
|
||||
ID: uuid,
|
||||
Type: tokenType,
|
||||
Username: username,
|
||||
IssuedAt: time.Now(),
|
||||
ExpiredAt: time.Now().Add(duration),
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// Valid checks if the token payload is valid or not
|
||||
func (payload *Payload) Valid(tokenType Type) error {
|
||||
if payload.Type != tokenType {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
if time.Now().After(payload.ExpiredAt) {
|
||||
return ErrExpiredToken
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (payload *Payload) GetExpirationTime() (*jwt.NumericDate, error) {
|
||||
return &jwt.NumericDate{
|
||||
Time: payload.ExpiredAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (payload *Payload) GetIssuedAt() (*jwt.NumericDate, error) {
|
||||
return &jwt.NumericDate{
|
||||
Time: payload.IssuedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (payload *Payload) GetNotBefore() (*jwt.NumericDate, error) {
|
||||
return &jwt.NumericDate{
|
||||
Time: payload.IssuedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (payload *Payload) GetIssuer() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (payload *Payload) GetSubject() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (payload *Payload) GetAudience() (jwt.ClaimStrings, error) {
|
||||
return jwt.ClaimStrings{}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user