first commit

This commit is contained in:
2025-03-21 11:05:42 +08:00
commit 7dffc94035
1717 changed files with 724764 additions and 0 deletions

View File

@@ -0,0 +1,228 @@
// Package convertor 实现了一些函数来转换数据
package convertor
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"math"
"reflect"
"regexp"
"strconv"
"strings"
)
// ToBool 将字符串转换为布尔值
func ToBool(s string) (bool, error) {
return strconv.ParseBool(s)
}
// ToBytes 将接口转换为字节数
func ToBytes(value any) ([]byte, error) {
v := reflect.ValueOf(value)
switch value.(type) {
case int, int8, int16, int32, int64:
number := v.Int()
buf := bytes.NewBuffer([]byte{})
buf.Reset()
err := binary.Write(buf, binary.BigEndian, number)
return buf.Bytes(), err
case uint, uint8, uint16, uint32, uint64:
number := v.Uint()
buf := bytes.NewBuffer([]byte{})
buf.Reset()
err := binary.Write(buf, binary.BigEndian, number)
return buf.Bytes(), err
case float32:
number := float32(v.Float())
bits := math.Float32bits(number)
bys := make([]byte, 4)
binary.BigEndian.PutUint32(bys, bits)
return bys, nil
case float64:
number := v.Float()
bits := math.Float64bits(number)
bys := make([]byte, 8)
binary.BigEndian.PutUint64(bys, bits)
return bys, nil
case bool:
return strconv.AppendBool([]byte{}, v.Bool()), nil
case string:
return []byte(v.String()), nil
case []byte:
return v.Bytes(), nil
default:
newValue, err := json.Marshal(value)
return newValue, err
}
}
// ToChar 将字符串转换为char slice
func ToChar(s string) []string {
c := make([]string, 0)
if len(s) == 0 {
c = append(c, "")
}
for _, v := range s {
c = append(c, string(v))
}
return c
}
// ToString 将值转换为字符串
func ToString(value any) string {
res := ""
if value == nil {
return res
}
v := reflect.ValueOf(value)
switch value.(type) {
case float32, float64:
res = strconv.FormatFloat(v.Float(), 'f', -1, 64)
return res
case int, int8, int16, int32, int64:
res = strconv.FormatInt(v.Int(), 10)
return res
case uint, uint8, uint16, uint32, uint64:
res = strconv.FormatUint(v.Uint(), 10)
return res
case string:
res = v.String()
return res
case []byte:
res = string(v.Bytes())
return res
default:
newValue, _ := json.Marshal(value)
res = string(newValue)
return res
}
}
// ToJson 将值转换为有效的json字符串
func ToJson(value any) (string, error) {
res, err := json.Marshal(value)
if err != nil {
return "", err
}
return string(res), nil
}
// ToFloat 将数值转换为float64,如果输入的不是float,则返回0.0和错误
func ToFloat(value any) (float64, error) {
v := reflect.ValueOf(value)
res := 0.0
err := fmt.Errorf("ToInt: unvalid interface type %T", value)
switch value.(type) {
case int, int8, int16, int32, int64:
res = float64(v.Int())
return res, nil
case uint, uint8, uint16, uint32, uint64:
res = float64(v.Uint())
return res, nil
case float32, float64:
res = v.Float()
return res, nil
case string:
res, err = strconv.ParseFloat(v.String(), 64)
if err != nil {
res = 0.0
}
return res, err
default:
return res, err
}
}
// ToInt 将数值转换为int64,如果输入的不是数字格式,则返回0和错误
func ToInt(value any) (int64, error) {
v := reflect.ValueOf(value)
var res int64
err := fmt.Errorf("ToInt: invalid interface type %T", value)
switch value.(type) {
case int, int8, int16, int32, int64:
res = v.Int()
return res, nil
case uint, uint8, uint16, uint32, uint64:
res = int64(v.Uint())
return res, nil
case float32, float64:
res = int64(v.Float())
return res, nil
case string:
res, err = strconv.ParseInt(v.String(), 0, 64)
if err != nil {
res = 0
}
return res, err
default:
return res, err
}
}
// StructToMap 将结构体转换为Map,只转换导出的结构体字段
// Map key的指定与结构字段标签`json`值相同
func StructToMap(value any) (map[string]any, error) {
v := reflect.ValueOf(value)
t := reflect.TypeOf(value)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return nil, fmt.Errorf("data type %T not support, shuld be struct or pointer to struct", value)
}
res := make(map[string]any)
fieldNum := t.NumField()
pattern := `^[A-Z]`
regex := regexp.MustCompile(pattern)
for i := 0; i < fieldNum; i++ {
name := t.Field(i).Name
tag := t.Field(i).Tag.Get("json")
if regex.MatchString(name) && tag != "" {
// res[name] = v.Field(i).Interface()
res[tag] = v.Field(i).Interface()
}
}
return res, nil
}
// ColorHexToRGB 将十六进制颜色转换为RGB颜色
func ColorHexToRGB(colorHex string) (red, green, blue int) {
colorHex = strings.TrimPrefix(colorHex, "#")
color64, err := strconv.ParseInt(colorHex, 16, 32)
if err != nil {
return
}
color := int(color64)
return color >> 16, (color & 0x00FF00) >> 8, color & 0x0000FF
}
// ColorRGBToHex 将RGB颜色转换为十六进制颜色
func ColorRGBToHex(red, green, blue int) string {
r := strconv.FormatInt(int64(red), 16)
g := strconv.FormatInt(int64(green), 16)
b := strconv.FormatInt(int64(blue), 16)
if len(r) == 1 {
r = "0" + r
}
if len(g) == 1 {
g = "0" + g
}
if len(b) == 1 {
b = "0" + b
}
return "#" + r + g + b
}

View File

@@ -0,0 +1,11 @@
package convertor
import "strconv"
func ConvertInt[T int | int16 | int32 | int64](value string, defaultValue T) T {
i, err := strconv.Atoi(value)
if err != nil {
return defaultValue
}
return T(i)
}

View File

@@ -0,0 +1,18 @@
package convertor
import "github.com/jackc/pgx/v5/pgtype"
func NumericToFloat64(num pgtype.Numeric) float64 {
if !num.Valid {
return 0
}
f1, err := num.Float64Value()
if err != nil {
return 0
}
f2, err := f1.Value()
if err != nil {
return 0
}
return f2.(float64)
}

View File

@@ -0,0 +1,72 @@
package crypto
import (
"crypto/rand"
"encoding/hex"
"fmt"
"strings"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/scrypt"
)
// ******************** scrypt ********************
// ScryptHashPassword scrypt 加密
// password 原始密码
func ScryptHashPassword(password string) (string, error) {
// example for making salt - https://play.golang.org/p/_Aw6WeWC42I
salt := make([]byte, 32)
_, err := rand.Read(salt)
if err != nil {
return "", err
}
// using recommended cost parameters from - https://godoc.org/golang.org/x/crypto/scrypt
hash, err := scrypt.Key([]byte(password), salt, 32768, 8, 1, 32)
if err != nil {
return "", err
}
// return hex-encoded string with salt appended to password
hashedPW := fmt.Sprintf("%s.%s", hex.EncodeToString(hash), hex.EncodeToString(salt))
return hashedPW, nil
}
// ScryptComparePassword 判断密码是否正确
// storedPassword 加密密码
// suppliedPassword 原始密码
func ScryptComparePassword(storedPassword string, suppliedPassword string) (bool, error) {
pwdSalt := strings.Split(storedPassword, ".")
// check supplied password salted with hash
salt, err := hex.DecodeString(pwdSalt[1])
if err != nil {
return false, fmt.Errorf("unable to verify user password")
}
hash, err := scrypt.Key([]byte(suppliedPassword), salt, 32768, 8, 1, 32)
if err != nil {
return false, err
}
return hex.EncodeToString(hash) == pwdSalt[0], nil
}
// ******************** bcrypt ********************
// BcryptHashPassword bcrypt 加密
// password 原始密码
func BcryptHashPassword(password string) ([]byte, error) {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
}
return hashedPassword, nil
}
// BcryptComparePassword 判断密码是否正确
// hashedPassword 加密密码
// password 原始密码
func BcryptComparePassword(hashedPassword []byte, password string) error {
return bcrypt.CompareHashAndPassword(hashedPassword, []byte(password))
}

View File

@@ -0,0 +1,233 @@
package fetcher
import (
"bytes"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"time"
)
// Get get请求
func Get(url string, timeout time.Duration) ([]byte, int, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, 0, fmt.Errorf("[get] new request err: %v", err)
}
client := &http.Client{Timeout: timeout}
resp, err := client.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("[get] client do err: %v", err)
}
defer resp.Body.Close()
status := resp.StatusCode
res, err := io.ReadAll(resp.Body)
if err != nil {
return nil, status, fmt.Errorf("read response err: %v", err)
}
return res, status, nil
}
// GetString 请求
func GetString(url string, parameter map[string]string, timeout time.Duration) ([]byte, int, error) {
byteParameter := new(bytes.Buffer)
w := multipart.NewWriter(byteParameter)
for k, v := range parameter {
w.WriteField(k, v)
}
w.Close()
request, err := http.NewRequest("GET", url, byteParameter)
if err != nil {
return nil, 0, fmt.Errorf("[getstring] new request err: %v", err)
}
request.Header.Set("Content-Type", w.FormDataContentType())
client := &http.Client{Timeout: time.Second * timeout}
response, err := client.Do(request)
if err != nil {
return nil, 0, err
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(response.Body)
status := response.StatusCode
resp, err := io.ReadAll(response.Body)
if err != nil {
return nil, status, err
}
return resp, status, nil
}
// GetJson application/json get 请求
func GetJson(url string, parameter []byte, timeout time.Duration) ([]byte, int, error) {
byteParameter := bytes.NewBuffer(parameter)
req, err := http.NewRequest("GET", url, byteParameter)
if err != nil {
return nil, 0, fmt.Errorf("[getjson] new request err: %v", err)
}
client := &http.Client{Timeout: timeout}
resp, err := client.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("[getjson] client do err: %v", err)
}
defer resp.Body.Close()
status := resp.StatusCode
res, err := io.ReadAll(resp.Body)
if err != nil {
return nil, status, fmt.Errorf("read response err: %v", err)
}
return res, status, nil
}
// GetJsonWithToken Get json with token 请求
func GetJsonWithToken(url string, token string, timeout time.Duration) ([]byte, int, error) {
request, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, 0, err
}
request.Header.Set("Content-type", "application/json")
request.Header.Set("Access-Token", token)
client := &http.Client{Timeout: timeout}
response, err := client.Do(request)
if err != nil {
return nil, 0, errors.New("请求网络错误")
}
defer response.Body.Close()
status := response.StatusCode
resp, err := io.ReadAll(response.Body)
if err != nil {
return nil, status, err
}
return resp, status, nil
}
// PostJson application/json post 请求
func PostJson(url string, parameter []byte, timeout time.Duration) ([]byte, int, error) {
byteParameter := bytes.NewBuffer(parameter)
request, err := http.NewRequest("POST", url, byteParameter)
if err != nil {
return nil, 0, fmt.Errorf("[postjson] new request err: %v", err)
}
request.Header.Set("Content-type", "application/json")
client := &http.Client{Timeout: timeout}
response, err := client.Do(request)
if err != nil {
return nil, 0, fmt.Errorf("[postjson] client do err: %v", err)
}
defer response.Body.Close()
status := response.StatusCode
all, err := io.ReadAll(response.Body)
if err != nil {
return nil, status, fmt.Errorf("read response err: %v", err)
}
return all, status, nil
}
// PostString application/x-www-form-urlencoded 请求
func PostString(url string, parameter []byte, timeout time.Duration) ([]byte, int, error) {
byteParameter := bytes.NewBuffer(parameter)
request, err := http.NewRequest("POST", url, byteParameter)
if err != nil {
return nil, 0, fmt.Errorf("[poststring] new request err: %v", err)
}
request.Header.Set("Content-type", "application/x-www-form-urlencoded")
client := &http.Client{Timeout: timeout}
response, err := client.Do(request)
if err != nil {
return nil, 0, fmt.Errorf("[poststring] client do err: %v", err)
}
defer response.Body.Close()
status := response.StatusCode
all, err := io.ReadAll(response.Body)
if err != nil {
return nil, status, fmt.Errorf("read response err: %v", err)
}
return all, status, nil
}
// PostJsonWithToken Post json with token 请求
func PostJsonWithToken(url string, parameter []byte, token string, timeout time.Duration) ([]byte, int, error) {
bufParameter := bytes.NewBuffer(parameter)
request, err := http.NewRequest("POST", url, bufParameter)
if err != nil {
return nil, 0, err
}
request.Header.Set("Content-type", "application/json")
request.Header.Set("Access-Token", token)
client := &http.Client{Timeout: timeout}
response, err := client.Do(request)
if err != nil {
return nil, 0, fmt.Errorf("请求网络错误: %v", err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(response.Body)
stauts := response.StatusCode
resp, err := io.ReadAll(response.Body)
if err != nil {
return nil, stauts, errors.New("网络请求结果读取失败")
}
return resp, stauts, nil
}
// PostJsonWithBearerToken Post json with bearer token 请求
func PostJsonWithBearerToken(url string, parameter []byte, token string, timeout time.Duration) ([]byte, int, error) {
bufParameter := bytes.NewBuffer(parameter)
request, err := http.NewRequest("POST", url, bufParameter)
if err != nil {
return nil, 0, err
}
request.Header.Set("Content-type", "application/json")
request.Header.Set("Authorization", fmt.Sprintf("bearer %s", token))
client := &http.Client{Timeout: timeout}
response, err := client.Do(request)
if err != nil {
return nil, 0, fmt.Errorf("请求网络错误: %v", err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(response.Body)
status := response.StatusCode
resp, err := io.ReadAll(response.Body)
if err != nil {
return nil, status, errors.New("网络请求结果读取失败")
}
return resp, status, nil
}
// func determiningEncoding(r *bufio.Reader) encoding.Encoding {
// b, err := r.Peek(1024)
// if err != nil && err != io.EOF {
// log.Printf("Fetcher error: %v", err)
// return unicode.UTF8
// }
// e, _, _ := charset.DetermineEncoding(b, "")
// return e
// }

20
internal/pkg/file/file.go Normal file
View File

@@ -0,0 +1,20 @@
package file
import "os"
func Exists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
func Mkdir(path string) error {
return os.MkdirAll(path, os.ModePerm)
}

159
internal/pkg/file/upload.go Normal file
View File

@@ -0,0 +1,159 @@
package file
import (
"errors"
"fmt"
"io"
"mime/multipart"
"os"
"path"
"time"
"github.com/h2non/filetype"
gonanoid "github.com/matoous/go-nanoid/v2"
)
const (
MaxImageSize = 10 << 20 // 10 MB
MaxFileSize = 50 << 20 // 50 MB
)
var ErrUnsupported = errors.New("文件格式不支持")
type FileType int
const (
ALL FileType = 0
IMG FileType = 1
)
func UploadFilename(filepath string, t FileType) (string, error) {
fileOpen, err := os.Open(filepath)
if err != nil {
return "", err
}
defer fileOpen.Close()
fileBytes, err := io.ReadAll(fileOpen)
if err != nil {
return "", errors.New("failed to read file")
}
if t == IMG {
// 判断是不是图片
if !filetype.IsImage(fileBytes) {
return "", ErrUnsupported
}
}
kind, err := filetype.Match(fileBytes)
if err != nil {
return "", err
}
if kind == filetype.Unknown {
return "", ErrUnsupported
}
// 使用 filetype 判断类型后已经去读了一些bytes了
// 要恢复文件读取位置
_, err = fileOpen.Seek(0, io.SeekStart)
if err != nil {
return "", err
}
dir := GetPath()
exist, _ := Exists(dir)
if !exist {
if err := Mkdir(dir); err != nil {
return "", err
}
}
filename := GenFilename(kind.Extension)
path := path.Join(dir, filename)
f, err := os.Create(path)
if err != nil {
return "", err
}
defer f.Close()
_, err = io.Copy(f, fileOpen)
if err != nil {
return "", err
}
return "/" + path, nil
}
func UploadFile(file *multipart.FileHeader, t FileType) (string, error) {
if file.Size > MaxFileSize {
return "", errors.New("failed to receive file too large")
}
fileOpen, err := file.Open()
if err != nil {
return "", errors.New("fialed to open file")
}
defer fileOpen.Close()
fileBytes, err := io.ReadAll(fileOpen)
if err != nil {
return "", errors.New("failed to read file")
}
if t == IMG {
// 判断是不是图片
if !filetype.IsImage(fileBytes) {
return "", ErrUnsupported
}
}
kind, err := filetype.Match(fileBytes)
if err != nil {
return "", err
}
if kind == filetype.Unknown {
return "", ErrUnsupported
}
// 使用 filetype 判断类型后已经去读了一些bytes了
// 要恢复文件读取位置
_, err = fileOpen.Seek(0, io.SeekStart)
if err != nil {
return "", err
}
dir := GetPath()
exist, _ := Exists(dir)
if !exist {
if err := Mkdir(dir); err != nil {
return "", err
}
}
filename := GenFilename(kind.Extension)
path := path.Join(dir, filename)
f, err := os.Create(path)
if err != nil {
return "", err
}
defer f.Close()
_, err = io.Copy(f, fileOpen)
if err != nil {
return "", err
}
return "/" + path, nil
}
func GetPath() string {
return fmt.Sprintf("upload/%s/%s/%s/", time.Now().Format("2006"), time.Now().Format("01"), time.Now().Format("02"))
}
func GenFilename(ext string) string {
id, _ := gonanoid.New()
return fmt.Sprintf("%s.%s", id, ext)
}

View File

@@ -0,0 +1,26 @@
package gu
import (
"net/http"
"github.com/gin-gonic/gin"
)
func Cors() gin.HandlerFunc {
return func(c *gin.Context) {
method := c.Request.Method
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Headers", "Content-Type, AccessToken, X-CSRF-Token, Authorization, Token")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type")
c.Header("Access-Control-Allow-Credentials", "true")
// 放行所有OPTIONS方法
if method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
}
// 处理请求
c.Next()
}
}

View File

@@ -0,0 +1,40 @@
package gu
import (
"net/http"
"github.com/gin-gonic/gin"
)
type response struct {
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data"`
}
type PageData struct {
Total int64 `json:"total"`
PageID int32 `json:"page_id"`
PageSize int32 `json:"page_size"`
Result any `json:"result"`
}
func Ok(ctx *gin.Context, data any) {
ResponseJson(ctx, http.StatusOK, "ok", data)
}
func Failed(ctx *gin.Context, message string) {
ResponseJson(ctx, http.StatusInternalServerError, message, nil)
}
func FailedWithCode(ctx *gin.Context, code int, message string) {
ResponseJson(ctx, code, message, nil)
}
func ResponseJson(ctx *gin.Context, code int, message string, data any) {
ctx.JSON(code, response{
Code: code,
Message: message,
Data: data,
})
}

View File

@@ -0,0 +1,57 @@
package gu
import (
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/locales/en"
"github.com/go-playground/locales/zh"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
enTranslations "github.com/go-playground/validator/v10/translations/en"
chTranslations "github.com/go-playground/validator/v10/translations/zh"
)
var trans ut.Translator
// loca 通常取决于 http 请求头的 'Accept-Language'
func SetValidatorTrans(local string) (err error) {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
zhT := zh.New() // chinese
enT := en.New() // english
uni := ut.New(enT, zhT, enT)
var o bool
trans, o = uni.GetTranslator(local)
if !o {
return fmt.Errorf("uni.GetTranslator(%s) failed", local)
}
// register translate
// 注册翻译器
switch local {
case "en":
err = enTranslations.RegisterDefaultTranslations(v, trans)
case "zh":
err = chTranslations.RegisterDefaultTranslations(v, trans)
default:
err = enTranslations.RegisterDefaultTranslations(v, trans)
}
return
}
return
}
func ValidatorErrors(ctx *gin.Context, err error) {
if errors, ok := err.(validator.ValidationErrors); ok {
errs := gin.H{}
for _, e := range errors {
errs[e.StructField()] = strings.Replace(e.Translate(trans), e.StructField(), "", -1)
}
ctx.JSON(http.StatusBadRequest, errs)
} else {
ctx.JSON(http.StatusBadRequest, err)
}
}

View File

@@ -0,0 +1,33 @@
package logger
import (
"os"
"time"
"management/internal/config"
"github.com/natefinch/lumberjack"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func Init() {
zerolog.SetGlobalLevel(zerolog.InfoLevel)
logRotate := &lumberjack.Logger{
Filename: "./log/run.log", // 日志文件的位置
MaxSize: 10, // 在进行切割之前,日志文件的最大大小(以MB为单位)
MaxBackups: 100, // 保留旧文件的最大个数
MaxAge: 30, // 保留旧文件的最大天数
Compress: true,
}
zerolog.TimeFieldFormat = time.DateTime
log.Logger = log.With().Caller().Logger()
if config.File.App.Prod {
log.Logger = log.Output(logRotate)
} else {
consoleWriter := zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.DateTime}
multi := zerolog.MultiLevelWriter(consoleWriter, logRotate)
log.Logger = log.Output(multi)
}
}

53
internal/pkg/rand/rand.go Normal file
View File

@@ -0,0 +1,53 @@
package rand
import (
"crypto/rand"
"encoding/base64"
"fmt"
randv2 "math/rand/v2"
"strings"
)
const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
func Bytes(n int) ([]byte, error) {
b := make([]byte, n)
nRead, err := rand.Read(b)
if err != nil {
return nil, fmt.Errorf("bytes: %w", err)
}
if nRead < n {
return nil, fmt.Errorf("bytes: didn't read enough random bytes")
}
return b, nil
}
func String(n int) (string, error) {
b, err := Bytes(n)
if err != nil {
return "", fmt.Errorf("string: %w", err)
}
return base64.URLEncoding.EncodeToString(b), nil
}
func RandomInt(n int) string {
letters := []byte("0123456789")
l := len(letters)
result := make([]byte, n)
for i := range result {
result[i] = letters[randv2.IntN(l)]
}
return string(result)
}
func RandomString(n int) string {
var sb strings.Builder
k := len(alphabet)
for i := 0; i < n; i++ {
c := alphabet[randv2.IntN(k)]
sb.WriteByte(c)
}
return sb.String()
}

126
internal/pkg/redis/redis.go Normal file
View File

@@ -0,0 +1,126 @@
package redis
import (
"bytes"
"context"
"encoding/gob"
"errors"
"fmt"
"time"
"management/internal/config"
"github.com/redis/go-redis/v9"
)
var (
engine *redis.Client
ErrRedisKeyNotFound = errors.New("redis key not found")
)
// func GetRedis() *redis.Client {
// return rd
// }
func Init() error {
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", config.File.Redis.Host, config.File.Redis.Port),
Password: config.File.Redis.Password,
DB: config.File.Redis.DB,
})
_, err := rdb.Ping(context.Background()).Result()
if err != nil {
return err
}
engine = rdb
return nil
}
func Encode(a any) ([]byte, error) {
var b bytes.Buffer
if err := gob.NewEncoder(&b).Encode(a); err != nil {
return nil, err
}
return b.Bytes(), nil
}
// Set 设置值
func Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
return engine.Set(ctx, key, value, expiration).Err()
}
// Del 删除键值
func Del(ctx context.Context, keys ...string) error {
return engine.Del(ctx, keys...).Err()
}
// Get 获取值
func Get(ctx context.Context, key string) (string, error) {
val, err := engine.Get(ctx, key).Result()
if err == redis.Nil {
return "", ErrRedisKeyNotFound
} else if err != nil {
return "", fmt.Errorf("cannot get value with:[%s]: %v", key, err)
} else {
return val, nil
}
}
// GetBytes 获取值
func GetBytes(ctx context.Context, key string) ([]byte, error) {
val, err := engine.Get(ctx, key).Bytes()
if err == redis.Nil {
return nil, ErrRedisKeyNotFound
} else if err != nil {
return nil, fmt.Errorf("cannot get value with:[%s]: %v", key, err)
} else {
return val, nil
}
}
func Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd {
return engine.Scan(ctx, cursor, match, count)
}
func Keys(ctx context.Context, pattern string) ([]string, error) {
return engine.Keys(ctx, pattern).Result()
}
func ListKeys(ctx context.Context, pattern string, pageID int, pageSize int) ([]string, int, error) {
all, err := engine.Keys(ctx, pattern).Result()
if err != nil {
return nil, 0, err
}
count := len(all)
if count == 0 {
return nil, 0, err
}
// 使用SCAN命令分页获取键
cursor := uint64(0)
var keys []string
for {
var scanResult []string
var err error
scanResult, cursor, err = engine.Scan(ctx, cursor, pattern, int64(pageSize)).Result()
if err != nil {
return nil, count, err
}
keys = append(keys, scanResult...)
if cursor == 0 {
break
}
}
startIndex := (pageID - 1) * pageSize
endIndex := startIndex + pageSize
if startIndex >= len(keys) {
return nil, count, nil
}
if endIndex > len(keys) {
endIndex = len(keys)
}
return keys[startIndex:endIndex], count, nil
}

View File

@@ -0,0 +1,83 @@
package session
import (
"context"
"time"
"management/internal/pkg/redis"
)
var (
storePrefix = "scs:session:"
ctx = context.Background()
DefaultRedisStore = newRedisStore()
)
type redisStore struct{}
func newRedisStore() *redisStore {
return &redisStore{}
}
// Delete should remove the session token and corresponding data from the
// session store. If the token does not exist then Delete should be a no-op
// and return nil (not an error).
func (s *redisStore) Delete(token string) error {
return redis.Del(ctx, storePrefix+token)
}
// Find should return the data for a session token from the store. If the
// session token is not found or is expired, the found return value should
// be false (and the err return value should be nil). Similarly, tampered
// or malformed tokens should result in a found return value of false and a
// nil err value. The err return value should be used for system errors only.
func (s *redisStore) Find(token string) (b []byte, found bool, err error) {
val, err := redis.GetBytes(ctx, storePrefix+token)
if err != nil {
return nil, false, err
} else {
return val, true, nil
}
}
// Commit should add the session token and data to the store, with the given
// expiry time. If the session token already exists, then the data and
// expiry time should be overwritten.
func (s *redisStore) Commit(token string, b []byte, expiry time.Time) error {
// TODO: 这边可以调整时间
exp, err := time.ParseInLocation(time.DateTime, time.Now().Format("2006-01-02")+" 23:59:59", time.Local)
if err != nil {
return err
}
t := time.Now()
expired := exp.Sub(t)
return redis.Set(ctx, storePrefix+token, b, expired)
}
// All should return a map containing data for all active sessions (i.e.
// sessions which have not expired). The map key should be the session
// token and the map value should be the session data. If no active
// sessions exist this should return an empty (not nil) map.
func (s *redisStore) All() (map[string][]byte, error) {
sessions := make(map[string][]byte)
iter := redis.Scan(ctx, 0, storePrefix+"*", 0).Iterator()
for iter.Next(ctx) {
key := iter.Val()
token := key[len(storePrefix):]
data, exists, err := s.Find(token)
if err != nil {
return nil, err
}
if exists {
sessions[token] = data
}
}
if err := iter.Err(); err != nil {
return nil, err
}
return sessions, nil
}

View File

@@ -0,0 +1,59 @@
package session
import (
"context"
"net/http"
"time"
"management/internal/config"
db "management/internal/db/sqlc"
"github.com/alexedwards/scs/pgxstore"
"github.com/alexedwards/scs/v2"
)
var sessionManager *scs.SessionManager
func Init() {
sessionManager = scs.New()
sessionManager.Lifetime = 24 * time.Hour
sessionManager.IdleTimeout = 2 * time.Hour
sessionManager.Cookie.Name = "token"
sessionManager.Cookie.HttpOnly = true
sessionManager.Cookie.Persist = true
sessionManager.Cookie.SameSite = http.SameSiteStrictMode
sessionManager.Cookie.Secure = config.File.App.Prod
// postgres
// github.com/alexedwards/scs/postgresstore
// sessionManager.Store = postgresstore.New(db)
// pgx
// github.com/alexedwards/scs/pgxstore
sessionManager.Store = pgxstore.New(db.Engine.Pool())
// redis
// sessionManager.Store = newRedisStore()
}
func Destroy(ctx context.Context) error {
return sessionManager.Destroy(ctx)
}
func LoadAndSave(next http.Handler) http.Handler {
return sessionManager.LoadAndSave(next)
}
func Put(ctx context.Context, key string, val interface{}) {
sessionManager.Put(ctx, key, val)
}
func GetBytes(ctx context.Context, key string) []byte {
return sessionManager.GetBytes(ctx, key)
}
func Exists(ctx context.Context, key string) bool {
return sessionManager.Exists(ctx, key)
}
func RenewToken(ctx context.Context) error {
return sessionManager.RenewToken(ctx)
}

41
internal/pkg/smb/smb.go Normal file
View File

@@ -0,0 +1,41 @@
package smb
import (
"io/fs"
"net"
"management/internal/config"
"github.com/hirochachacha/go-smb2"
)
var FS fs.FS
func Init() error {
conn, err := net.Dial("tcp", config.File.Smb.Host+":445")
if err != nil {
return err
}
defer conn.Close()
d := &smb2.Dialer{
Initiator: &smb2.NTLMInitiator{
User: config.File.Smb.Name,
Password: config.File.Smb.Pass,
},
}
s, err := d.Dial(conn)
if err != nil {
return err
}
defer s.Logoff()
fs, err := s.Mount(config.File.Smb.Mount)
if err != nil {
return err
}
defer fs.Umount()
FS = fs.DirFS(".")
return nil
}

View File

@@ -0,0 +1,20 @@
package snowflake
import (
"github.com/bwmarrin/snowflake"
)
var node *snowflake.Node
func Init() error {
n, err := snowflake.NewNode(1)
if err != nil {
return err
}
node = n
return nil
}
func GetId() int64 {
return node.Generate().Int64()
}

View File

@@ -0,0 +1,25 @@
package sqids
import "github.com/sqids/sqids-go"
var engine *sqids.Sqids
func Init() error {
var err error
engine, err = sqids.New(sqids.Options{
MinLength: 12,
Alphabet: "AvjM1lkB8N6cuhs2oFxnXyYDwCmLGI7JOzt9fW3HRgb5ZQrqaU04TePSVKdpiE",
})
if err != nil {
return err
}
return nil
}
func Encode(ids []uint64) (string, error) {
return engine.Encode(ids)
}
func Decode(s string) []uint64 {
return engine.Decode(s)
}

View File

@@ -0,0 +1,40 @@
package strutil
import "strings"
// splitWordsToLower 将一个字符串按大写字母分割成若干个字符串
func splitWordsToLower(s string) []string {
var res []string
upperIndexes := upperIndex(s)
l := len(upperIndexes)
if upperIndexes == nil || l == 0 {
if s != "" {
res = append(res, s)
}
return res
}
for i := 0; i < l; i++ {
if i < l-1 {
res = append(res, strings.ToLower(s[upperIndexes[i]:upperIndexes[i+1]]))
} else {
res = append(res, strings.ToLower(s[upperIndexes[i]:]))
}
}
return res
}
// upperIndex 获得一个int slice,其元素是一个字符串的所有大写字母索引
func upperIndex(s string) []int {
var res []int
for i := 0; i < len(s); i++ {
if 64 < s[i] && s[i] < 91 {
res = append(res, i)
}
}
if len(s) > 0 && res != nil && res[0] != 0 {
res = append([]int{0}, res...)
}
return res
}

View File

@@ -0,0 +1,296 @@
// Package strutil 实现了一些函数来操作字符串
package strutil
import (
"regexp"
"strings"
"unicode"
"unicode/utf8"
)
// CamelCase 转换字符串到驼峰法(CamelCase)
func CamelCase(s string) string {
if len(s) == 0 {
return ""
}
res := ""
blankSpace := " "
regex, _ := regexp.Compile("[-_&]+")
ss := regex.ReplaceAllString(s, blankSpace)
for i, v := range strings.Split(ss, blankSpace) {
vv := []rune(v)
if i == 0 {
if vv[i] >= 65 && vv[i] <= 96 {
vv[0] += 32
}
res += string(vv)
} else {
res += Capitalize(v)
}
}
return res
}
// Capitalize 将一个字符串的第一个字符转换为大写,其余的转换为小写
func Capitalize(s string) string {
if len(s) == 0 {
return ""
}
out := make([]rune, len(s))
for i, v := range s {
if i == 0 {
out[i] = unicode.ToUpper(v)
} else {
out[i] = unicode.ToLower(v)
}
}
return string(out)
}
// UpperFirst 将字符串的第一个字符转换为大写
func UpperFirst(s string) string {
if len(s) == 0 {
return ""
}
r, size := utf8.DecodeRuneInString(s)
r = unicode.ToUpper(r)
return string(r) + s[size:]
}
// LowerFirst 将字符串的第一个字符转换为小写
func LowerFirst(s string) string {
if len(s) == 0 {
return ""
}
r, size := utf8.DecodeRuneInString(s)
r = unicode.ToLower(r)
return string(r) + s[size:]
}
// PadEnd 如果字符串比尺寸短,则将其垫在右侧
// 填充字符如果超过大小,将被截断
func PadEnd(source string, size int, padStr string) string {
len1 := len(source)
len2 := len(padStr)
if len1 >= size {
return source
}
fill := ""
if len2 >= size-len1 {
fill = padStr[0 : size-len1]
} else {
fill = strings.Repeat(padStr, size-len1)
}
return source + fill[0:size-len1]
}
// PadStart 如果字符串比尺寸短,则将其垫在左侧
// 填充字符如果超过大小,将被截断
func PadStart(source string, size int, padStr string) string {
len1 := len(source)
len2 := len(padStr)
if len1 >= size {
return source
}
fill := ""
if len2 >= size-len1 {
fill = padStr[0 : size-len1]
} else {
fill = strings.Repeat(padStr, size-len1)
}
return fill[0:size-len1] + source
}
// KebabCase 将字符串转为短横线隔开式(kebab-case)
func KebabCase(s string) string {
if len(s) == 0 {
return ""
}
regex := regexp.MustCompile(`[\W|_]+`)
blankSpace := " "
match := regex.ReplaceAllString(s, blankSpace)
rs := strings.Split(match, blankSpace)
var res []string
for _, v := range rs {
splitWords := splitWordsToLower(v)
if len(splitWords) > 0 {
res = append(res, splitWords...)
}
}
return strings.Join(res, "-")
}
// SnakeCase 将字符串转为蛇形命名(snake_case)
func SnakeCase(s string) string {
if len(s) == 0 {
return ""
}
regex := regexp.MustCompile(`[\W|_]+`)
blankSpace := " "
match := regex.ReplaceAllString(s, blankSpace)
rs := strings.Split(match, blankSpace)
var res []string
for _, v := range rs {
splitWords := splitWordsToLower(v)
if len(splitWords) > 0 {
res = append(res, splitWords...)
}
}
return strings.Join(res, "_")
}
// Before 在字符首次出现的位置之前,在源字符串中创建子串
func Before(s, char string) string {
if s == "" || char == "" {
return s
}
i := strings.Index(s, char)
return s[0:i]
}
// BeforeLast 在字符最后出现的位置之前,在源字符串中创建子串
func BeforeLast(s, char string) string {
if s == "" || char == "" {
return s
}
i := strings.LastIndex(s, char)
return s[0:i]
}
// After 在字符首次出现的位置后,在源字符串中创建子串
func After(s, char string) string {
if s == "" || char == "" {
return s
}
i := strings.Index(s, char)
return s[i+len(char):]
}
// AfterLast 在字符最后出现的位置后,在源字符串中创建子串
func AfterLast(s, char string) string {
if s == "" || char == "" {
return s
}
i := strings.LastIndex(s, char)
return s[i+len(char):]
}
// IsString 检查值的数据类型是否为字符串
func IsString(v any) bool {
if v == nil {
return false
}
switch v.(type) {
case string:
return true
default:
return false
}
}
// ReverseStr 返回字符顺序与给定字符串相反的字符串
func ReverseStr(s string) string {
r := []rune(s)
for i, j := 0, len(r)-1; i < j; i, j = i+1, j-1 {
r[i], r[j] = r[j], r[i]
}
return string(r)
}
// Wrap 用另一个字符串包住一个字符串
func Wrap(str string, wrapWith string) string {
if str == "" || wrapWith == "" {
return str
}
var sb strings.Builder
sb.WriteString(wrapWith)
sb.WriteString(str)
sb.WriteString(wrapWith)
return sb.String()
}
// Unwrap 从另一个字符串中解开一个给定的字符串,将改变str值
func Unwrap(str string, wrapToken string) string {
if str == "" || wrapToken == "" {
return str
}
firstIndex := strings.Index(str, wrapToken)
lastIndex := strings.LastIndex(str, wrapToken)
if firstIndex == 0 && lastIndex > 0 && lastIndex <= len(str)-1 {
if len(wrapToken) <= lastIndex {
str = str[len(wrapToken):lastIndex]
}
}
return str
}
// RemoveHTML 去除字符串中的 html, js
func RemoveHTML(str string) string {
if len(str) > 0 {
// 删除脚本
reg := regexp.MustCompile(`([\r\n])[\s]+`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`<script[^>]*?>.*?</script>`)
str = reg.ReplaceAllString(str, "")
// 删除HTML
reg = regexp.MustCompile(`<(.[^>]*)>`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`([\r\n])[\s]+`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`-->`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`<!--.*`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`&(quot|#34);`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`&(amp|#38);`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`&(lt|#60);`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`&(gt|#62);`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`&(nbsp|#160);`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`&(iexcl|#161);`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`&(cent|#162);`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`&(pound|#163);`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`&(copy|#169);`)
str = reg.ReplaceAllString(str, "")
reg = regexp.MustCompile(`&#(\d+);`)
str = reg.ReplaceAllString(str, "")
str = strings.ReplaceAll(str, "<", "")
str = strings.ReplaceAll(str, ">", "")
str = strings.ReplaceAll(str, "\n", "")
str = strings.ReplaceAll(str, " ", "")
str = strings.ReplaceAll(str, " ", "")
return str
}
return ""
}

View File

@@ -0,0 +1,59 @@
package token
import (
"fmt"
"time"
tk "github.com/golang-jwt/jwt/v5"
)
const minSecretKeySize = 32
// JWTMaker JSON Web Token
type JWTMaker struct {
secretKey string
}
// NewJWTMaker 创建一个新的JWTMaker
func NewJWTMaker(secretKey string) error {
if len(secretKey) < minSecretKeySize {
return fmt.Errorf("invalid key size: must be at least %d characters", minSecretKeySize)
}
engine = &JWTMaker{secretKey}
return nil
}
// CreateToken 根据用户名和时间创建一个新的token
func (maker *JWTMaker) CreateToken(id string, username string, duration time.Duration) (string, *Payload, error) {
payload, err := NewPayload(id, username, duration)
if err != nil {
return "", payload, err
}
jwtToken := tk.NewWithClaims(tk.SigningMethodHS256, payload)
token, err := jwtToken.SignedString([]byte(maker.secretKey))
return token, payload, err
}
// VerifyToken checks if the token is valid or not
func (maker *JWTMaker) VerifyToken(token string) (*Payload, error) {
keyFunc := func(token *tk.Token) (interface{}, error) {
_, ok := token.Method.(*tk.SigningMethodHMAC)
if !ok {
return nil, ErrInvalidToken
}
return []byte(maker.secretKey), nil
}
jwtToken, err := tk.ParseWithClaims(token, &Payload{}, keyFunc)
if err != nil {
return nil, ErrInvalidToken
}
payload, ok := jwtToken.Claims.(*Payload)
if !ok {
return nil, ErrInvalidToken
}
return payload, nil
}

View File

@@ -0,0 +1,24 @@
package token
import (
"time"
)
var engine Maker
// Maker 管理token的接口定义
type Maker interface {
// CreateToken 根据用户名和时间创建一个新的token
CreateToken(id string, username string, duration time.Duration) (string, *Payload, error)
// VerifyToken 校验token是否正确
VerifyToken(token string) (*Payload, error)
}
func CreateToken(id string, username string, duration time.Duration) (string, *Payload, error) {
return engine.CreateToken(id, username, duration)
}
func VerifyToken(token string) (*Payload, error) {
return engine.VerifyToken(token)
}

View File

@@ -0,0 +1,60 @@
package token
import (
"fmt"
"time"
"management/internal/config"
"github.com/aead/chacha20poly1305"
"github.com/o1egl/paseto"
)
// PasetoMaker is a PASETO token maker
type PasetoMaker struct {
paseto *paseto.V2
symmetricKey []byte
}
// NewPasetoMaker creates a new PasetoMaker
func NewPasetoMaker() error {
symmetricKey := config.File.JWT.SigningKey
if len(symmetricKey) != chacha20poly1305.KeySize {
return fmt.Errorf("invalid key size: must be exactly %d characters", chacha20poly1305.KeySize)
}
engine = &PasetoMaker{
paseto: paseto.NewV2(),
symmetricKey: []byte(symmetricKey),
}
return nil
}
// CreateToken creates a new token for a specific username and duration
func (maker *PasetoMaker) CreateToken(id string, username string, duration time.Duration) (string, *Payload, error) {
payload, err := NewPayload(id, username, duration)
if err != nil {
return "", payload, err
}
token, err := maker.paseto.Encrypt(maker.symmetricKey, payload, nil)
return token, payload, err
}
// VerifyToken checks if the token is valid or not
func (maker *PasetoMaker) VerifyToken(token string) (*Payload, error) {
payload := &Payload{}
err := maker.paseto.Decrypt(token, maker.symmetricKey, payload, nil)
if err != nil {
return nil, ErrInvalidToken
}
err = payload.Valid()
if err != nil {
return nil, err
}
return payload, nil
}

View File

@@ -0,0 +1,43 @@
package token
import (
"errors"
"time"
tk "github.com/golang-jwt/jwt/v5"
)
// Different types of error returned by the VerifyToken function
var (
ErrInvalidToken = errors.New("token is invalid")
ErrExpiredToken = errors.New("token has expired")
)
// Payload contains the payload data of the token
type Payload struct {
ID string `json:"id"`
Username string `json:"username"`
tk.RegisteredClaims // v5版本新加的方法
}
// NewPayload creates a new token payload with a specific username and duration
func NewPayload(id string, username string, duration time.Duration) (*Payload, error) {
payload := &Payload{
id,
username,
tk.RegisteredClaims{
ExpiresAt: tk.NewNumericDate(time.Now().Add(duration)), // 过期时间24小时
IssuedAt: tk.NewNumericDate(time.Now()), // 签发时间
NotBefore: tk.NewNumericDate(time.Now()), // 生效时间
},
}
return payload, nil
}
// Valid checks if the token payload is valid or not
func (payload *Payload) Valid() error {
if time.Now().After(payload.ExpiresAt.Time) {
return ErrExpiredToken
}
return nil
}

View File

@@ -0,0 +1,78 @@
package validation
import (
"errors"
"regexp"
)
// 使用预编译的全局正则表达式,避免重复创建和编译.
var (
lengthRegex = regexp.MustCompile(`^.{3,20}$`) // 长度在 3 到 20 个字符之间
validRegex = regexp.MustCompile(`^[A-Za-z0-9_]+$`) // 仅包含字母、数字和下划线
letterRegex = regexp.MustCompile(`[A-Za-z]`) // 至少包含一个字母
numberRegex = regexp.MustCompile(`\d`) // 至少包含一个数字
emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) // 邮箱格式
phoneRegex = regexp.MustCompile(`^1[3-9]\d{9}$`) // 中国手机号
)
// IsValidUsername 校验用户名是否合法.
func IsValidUsername(username string) bool {
// 校验长度
if !lengthRegex.MatchString(username) {
return false
}
// 校验字符合法性
if !validRegex.MatchString(username) {
return false
}
return true
}
// IsValidPassword 判断密码是否符合复杂度要求.
func IsValidPassword(password string) error {
switch {
// 检查新密码是否为空
case password == "":
return errors.New("password cannot be empty")
// 检查新密码的长度要求
case len(password) < 6:
return errors.New("password must be at least 6 characters long")
// 使用正则表达式检查是否至少包含一个字母
case !letterRegex.MatchString(password):
return errors.New("password must contain at least one letter")
// 使用正则表达式检查是否至少包含一个数字
case !numberRegex.MatchString(password):
return errors.New("password must contain at least one number")
}
return nil
}
// IsValidEmail 判断电子邮件是否合法.
func IsValidEmail(email string) error {
// 检查电子邮件地址格式
if email == "" {
return errors.New("email cannot be empty")
}
// 使用正则表达式校验电子邮件格式
if !emailRegex.MatchString(email) {
return errors.New("invalid email format")
}
return nil
}
// IsValidPhone 判断手机号码是否合法.
func IsValidPhone(phone string) error {
// 检查手机号码格式
if phone == "" {
return errors.New("phone cannot be empty")
}
// 使用正则表达式校验手机号码格式假设是中国手机号11位数字
if !phoneRegex.MatchString(phone) {
return errors.New("invalid phone format")
}
return nil
}

View File

@@ -0,0 +1,86 @@
package validation
import (
"errors"
"fmt"
"reflect"
"strings"
"time"
"github.com/go-playground/validator/v10"
)
// 初始化一个验证器实例
var validate = validator.New()
// 自定义验证规则requiredint
func init() {
validate.RegisterValidation("telephone", func(fl validator.FieldLevel) bool {
if err := IsValidPhone(fl.Field().String()); err != nil {
return false
}
return true
})
validate.RegisterValidation("dateonly", func(fl validator.FieldLevel) bool {
_, err := time.ParseInLocation("2006-01-02", fl.Field().String(), time.Local)
if err != nil {
return false
}
return true
})
}
func ValidateForm(s any) error {
// 验证结构体数据
if err := validate.Struct(s); err != nil {
if _, ok := err.(*validator.InvalidValidationError); ok {
return errors.New("验证器配置错误")
}
// 获取结构体的反射类型
t := reflect.TypeOf(s)
errorMessages := make([]string, 0)
for _, err := range err.(validator.ValidationErrors) {
// 获取字段名
fieldName := err.Field()
// 获取标签名
tag := err.Tag()
// 获取验证失败的字段值
// value := err.Value()
// 通过反射获取字段
field, found := t.Elem().FieldByName(fieldName)
errorMsg := fmt.Sprintf("[%s] 验证失败: %s", fieldName, translate(tag))
if found {
commentTag := field.Tag.Get("comment")
errorMsg = fmt.Sprintf("[%s] 验证失败: %s", commentTag, translate(tag))
}
errorMessages = append(errorMessages, errorMsg)
}
return errors.New(strings.Join(errorMessages, "; "))
}
return nil
}
func translate(s string) string {
switch s {
case "required":
return "不能为空"
case "min":
return "不能小于"
case "max":
return "不能大于"
case "email":
return "不是有效的邮箱地址"
case "telephone":
return "不是有效的手机号"
case "datetime":
return "不是有效的日期时间"
case "dateonly":
return "不是有效的日期"
}
return s
}