first commit
This commit is contained in:
228
internal/pkg/convertor/convertor.go
Normal file
228
internal/pkg/convertor/convertor.go
Normal file
@@ -0,0 +1,228 @@
|
||||
// Package convertor 实现了一些函数来转换数据
|
||||
package convertor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ToBool 将字符串转换为布尔值
|
||||
func ToBool(s string) (bool, error) {
|
||||
return strconv.ParseBool(s)
|
||||
}
|
||||
|
||||
// ToBytes 将接口转换为字节数
|
||||
func ToBytes(value any) ([]byte, error) {
|
||||
v := reflect.ValueOf(value)
|
||||
|
||||
switch value.(type) {
|
||||
case int, int8, int16, int32, int64:
|
||||
number := v.Int()
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
buf.Reset()
|
||||
err := binary.Write(buf, binary.BigEndian, number)
|
||||
return buf.Bytes(), err
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
number := v.Uint()
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
buf.Reset()
|
||||
err := binary.Write(buf, binary.BigEndian, number)
|
||||
return buf.Bytes(), err
|
||||
case float32:
|
||||
number := float32(v.Float())
|
||||
bits := math.Float32bits(number)
|
||||
bys := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(bys, bits)
|
||||
return bys, nil
|
||||
case float64:
|
||||
number := v.Float()
|
||||
bits := math.Float64bits(number)
|
||||
bys := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(bys, bits)
|
||||
return bys, nil
|
||||
case bool:
|
||||
return strconv.AppendBool([]byte{}, v.Bool()), nil
|
||||
case string:
|
||||
return []byte(v.String()), nil
|
||||
case []byte:
|
||||
return v.Bytes(), nil
|
||||
default:
|
||||
newValue, err := json.Marshal(value)
|
||||
return newValue, err
|
||||
}
|
||||
}
|
||||
|
||||
// ToChar 将字符串转换为char slice
|
||||
func ToChar(s string) []string {
|
||||
c := make([]string, 0)
|
||||
if len(s) == 0 {
|
||||
c = append(c, "")
|
||||
}
|
||||
for _, v := range s {
|
||||
c = append(c, string(v))
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// ToString 将值转换为字符串
|
||||
func ToString(value any) string {
|
||||
res := ""
|
||||
if value == nil {
|
||||
return res
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(value)
|
||||
|
||||
switch value.(type) {
|
||||
case float32, float64:
|
||||
res = strconv.FormatFloat(v.Float(), 'f', -1, 64)
|
||||
return res
|
||||
case int, int8, int16, int32, int64:
|
||||
res = strconv.FormatInt(v.Int(), 10)
|
||||
return res
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
res = strconv.FormatUint(v.Uint(), 10)
|
||||
return res
|
||||
case string:
|
||||
res = v.String()
|
||||
return res
|
||||
case []byte:
|
||||
res = string(v.Bytes())
|
||||
return res
|
||||
default:
|
||||
newValue, _ := json.Marshal(value)
|
||||
res = string(newValue)
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
// ToJson 将值转换为有效的json字符串
|
||||
func ToJson(value any) (string, error) {
|
||||
res, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(res), nil
|
||||
}
|
||||
|
||||
// ToFloat 将数值转换为float64,如果输入的不是float,则返回0.0和错误
|
||||
func ToFloat(value any) (float64, error) {
|
||||
v := reflect.ValueOf(value)
|
||||
|
||||
res := 0.0
|
||||
err := fmt.Errorf("ToInt: unvalid interface type %T", value)
|
||||
switch value.(type) {
|
||||
case int, int8, int16, int32, int64:
|
||||
res = float64(v.Int())
|
||||
return res, nil
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
res = float64(v.Uint())
|
||||
return res, nil
|
||||
case float32, float64:
|
||||
res = v.Float()
|
||||
return res, nil
|
||||
case string:
|
||||
res, err = strconv.ParseFloat(v.String(), 64)
|
||||
if err != nil {
|
||||
res = 0.0
|
||||
}
|
||||
return res, err
|
||||
default:
|
||||
return res, err
|
||||
}
|
||||
}
|
||||
|
||||
// ToInt 将数值转换为int64,如果输入的不是数字格式,则返回0和错误
|
||||
func ToInt(value any) (int64, error) {
|
||||
v := reflect.ValueOf(value)
|
||||
|
||||
var res int64
|
||||
err := fmt.Errorf("ToInt: invalid interface type %T", value)
|
||||
switch value.(type) {
|
||||
case int, int8, int16, int32, int64:
|
||||
res = v.Int()
|
||||
return res, nil
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
res = int64(v.Uint())
|
||||
return res, nil
|
||||
case float32, float64:
|
||||
res = int64(v.Float())
|
||||
return res, nil
|
||||
case string:
|
||||
res, err = strconv.ParseInt(v.String(), 0, 64)
|
||||
if err != nil {
|
||||
res = 0
|
||||
}
|
||||
return res, err
|
||||
default:
|
||||
return res, err
|
||||
}
|
||||
}
|
||||
|
||||
// StructToMap 将结构体转换为Map,只转换导出的结构体字段
|
||||
// Map key的指定与结构字段标签`json`值相同
|
||||
func StructToMap(value any) (map[string]any, error) {
|
||||
v := reflect.ValueOf(value)
|
||||
t := reflect.TypeOf(value)
|
||||
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("data type %T not support, shuld be struct or pointer to struct", value)
|
||||
}
|
||||
|
||||
res := make(map[string]any)
|
||||
|
||||
fieldNum := t.NumField()
|
||||
pattern := `^[A-Z]`
|
||||
regex := regexp.MustCompile(pattern)
|
||||
for i := 0; i < fieldNum; i++ {
|
||||
name := t.Field(i).Name
|
||||
tag := t.Field(i).Tag.Get("json")
|
||||
if regex.MatchString(name) && tag != "" {
|
||||
// res[name] = v.Field(i).Interface()
|
||||
res[tag] = v.Field(i).Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// ColorHexToRGB 将十六进制颜色转换为RGB颜色
|
||||
func ColorHexToRGB(colorHex string) (red, green, blue int) {
|
||||
colorHex = strings.TrimPrefix(colorHex, "#")
|
||||
color64, err := strconv.ParseInt(colorHex, 16, 32)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
color := int(color64)
|
||||
return color >> 16, (color & 0x00FF00) >> 8, color & 0x0000FF
|
||||
}
|
||||
|
||||
// ColorRGBToHex 将RGB颜色转换为十六进制颜色
|
||||
func ColorRGBToHex(red, green, blue int) string {
|
||||
r := strconv.FormatInt(int64(red), 16)
|
||||
g := strconv.FormatInt(int64(green), 16)
|
||||
b := strconv.FormatInt(int64(blue), 16)
|
||||
|
||||
if len(r) == 1 {
|
||||
r = "0" + r
|
||||
}
|
||||
if len(g) == 1 {
|
||||
g = "0" + g
|
||||
}
|
||||
if len(b) == 1 {
|
||||
b = "0" + b
|
||||
}
|
||||
|
||||
return "#" + r + g + b
|
||||
}
|
||||
11
internal/pkg/convertor/http_form.go
Normal file
11
internal/pkg/convertor/http_form.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package convertor
|
||||
|
||||
import "strconv"
|
||||
|
||||
func ConvertInt[T int | int16 | int32 | int64](value string, defaultValue T) T {
|
||||
i, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return T(i)
|
||||
}
|
||||
18
internal/pkg/convertor/pgx.go
Normal file
18
internal/pkg/convertor/pgx.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package convertor
|
||||
|
||||
import "github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
func NumericToFloat64(num pgtype.Numeric) float64 {
|
||||
if !num.Valid {
|
||||
return 0
|
||||
}
|
||||
f1, err := num.Float64Value()
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
f2, err := f1.Value()
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return f2.(float64)
|
||||
}
|
||||
72
internal/pkg/crypto/password.go
Normal file
72
internal/pkg/crypto/password.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
// ******************** scrypt ********************
|
||||
|
||||
// ScryptHashPassword scrypt 加密
|
||||
// password 原始密码
|
||||
func ScryptHashPassword(password string) (string, error) {
|
||||
// example for making salt - https://play.golang.org/p/_Aw6WeWC42I
|
||||
salt := make([]byte, 32)
|
||||
_, err := rand.Read(salt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// using recommended cost parameters from - https://godoc.org/golang.org/x/crypto/scrypt
|
||||
hash, err := scrypt.Key([]byte(password), salt, 32768, 8, 1, 32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// return hex-encoded string with salt appended to password
|
||||
hashedPW := fmt.Sprintf("%s.%s", hex.EncodeToString(hash), hex.EncodeToString(salt))
|
||||
return hashedPW, nil
|
||||
}
|
||||
|
||||
// ScryptComparePassword 判断密码是否正确
|
||||
// storedPassword 加密密码
|
||||
// suppliedPassword 原始密码
|
||||
func ScryptComparePassword(storedPassword string, suppliedPassword string) (bool, error) {
|
||||
pwdSalt := strings.Split(storedPassword, ".")
|
||||
|
||||
// check supplied password salted with hash
|
||||
salt, err := hex.DecodeString(pwdSalt[1])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("unable to verify user password")
|
||||
}
|
||||
|
||||
hash, err := scrypt.Key([]byte(suppliedPassword), salt, 32768, 8, 1, 32)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return hex.EncodeToString(hash) == pwdSalt[0], nil
|
||||
}
|
||||
|
||||
// ******************** bcrypt ********************
|
||||
|
||||
// BcryptHashPassword bcrypt 加密
|
||||
// password 原始密码
|
||||
func BcryptHashPassword(password string) ([]byte, error) {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
return hashedPassword, nil
|
||||
}
|
||||
|
||||
// BcryptComparePassword 判断密码是否正确
|
||||
// hashedPassword 加密密码
|
||||
// password 原始密码
|
||||
func BcryptComparePassword(hashedPassword []byte, password string) error {
|
||||
return bcrypt.CompareHashAndPassword(hashedPassword, []byte(password))
|
||||
}
|
||||
233
internal/pkg/fetcher/fetcher.go
Normal file
233
internal/pkg/fetcher/fetcher.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package fetcher
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Get get请求
|
||||
func Get(url string, timeout time.Duration) ([]byte, int, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("[get] new request err: %v", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("[get] client do err: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
status := resp.StatusCode
|
||||
res, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, status, fmt.Errorf("read response err: %v", err)
|
||||
}
|
||||
|
||||
return res, status, nil
|
||||
}
|
||||
|
||||
// GetString 请求
|
||||
func GetString(url string, parameter map[string]string, timeout time.Duration) ([]byte, int, error) {
|
||||
byteParameter := new(bytes.Buffer)
|
||||
w := multipart.NewWriter(byteParameter)
|
||||
for k, v := range parameter {
|
||||
w.WriteField(k, v)
|
||||
}
|
||||
w.Close()
|
||||
|
||||
request, err := http.NewRequest("GET", url, byteParameter)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("[getstring] new request err: %v", err)
|
||||
}
|
||||
request.Header.Set("Content-Type", w.FormDataContentType())
|
||||
|
||||
client := &http.Client{Timeout: time.Second * timeout}
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(response.Body)
|
||||
|
||||
status := response.StatusCode
|
||||
resp, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, status, err
|
||||
}
|
||||
|
||||
return resp, status, nil
|
||||
}
|
||||
|
||||
// GetJson application/json get 请求
|
||||
func GetJson(url string, parameter []byte, timeout time.Duration) ([]byte, int, error) {
|
||||
byteParameter := bytes.NewBuffer(parameter)
|
||||
req, err := http.NewRequest("GET", url, byteParameter)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("[getjson] new request err: %v", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("[getjson] client do err: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
status := resp.StatusCode
|
||||
res, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, status, fmt.Errorf("read response err: %v", err)
|
||||
}
|
||||
|
||||
return res, status, nil
|
||||
}
|
||||
|
||||
// GetJsonWithToken Get json with token 请求
|
||||
func GetJsonWithToken(url string, token string, timeout time.Duration) ([]byte, int, error) {
|
||||
request, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
request.Header.Set("Content-type", "application/json")
|
||||
request.Header.Set("Access-Token", token)
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return nil, 0, errors.New("请求网络错误")
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
status := response.StatusCode
|
||||
resp, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, status, err
|
||||
}
|
||||
|
||||
return resp, status, nil
|
||||
}
|
||||
|
||||
// PostJson application/json post 请求
|
||||
func PostJson(url string, parameter []byte, timeout time.Duration) ([]byte, int, error) {
|
||||
byteParameter := bytes.NewBuffer(parameter)
|
||||
request, err := http.NewRequest("POST", url, byteParameter)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("[postjson] new request err: %v", err)
|
||||
}
|
||||
request.Header.Set("Content-type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("[postjson] client do err: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
status := response.StatusCode
|
||||
all, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, status, fmt.Errorf("read response err: %v", err)
|
||||
}
|
||||
|
||||
return all, status, nil
|
||||
}
|
||||
|
||||
// PostString application/x-www-form-urlencoded 请求
|
||||
func PostString(url string, parameter []byte, timeout time.Duration) ([]byte, int, error) {
|
||||
byteParameter := bytes.NewBuffer(parameter)
|
||||
request, err := http.NewRequest("POST", url, byteParameter)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("[poststring] new request err: %v", err)
|
||||
}
|
||||
request.Header.Set("Content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("[poststring] client do err: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
status := response.StatusCode
|
||||
all, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, status, fmt.Errorf("read response err: %v", err)
|
||||
}
|
||||
|
||||
return all, status, nil
|
||||
}
|
||||
|
||||
// PostJsonWithToken Post json with token 请求
|
||||
func PostJsonWithToken(url string, parameter []byte, token string, timeout time.Duration) ([]byte, int, error) {
|
||||
bufParameter := bytes.NewBuffer(parameter)
|
||||
request, err := http.NewRequest("POST", url, bufParameter)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
request.Header.Set("Content-type", "application/json")
|
||||
request.Header.Set("Access-Token", token)
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("请求网络错误: %v", err)
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(response.Body)
|
||||
|
||||
stauts := response.StatusCode
|
||||
resp, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, stauts, errors.New("网络请求结果读取失败")
|
||||
}
|
||||
|
||||
return resp, stauts, nil
|
||||
}
|
||||
|
||||
// PostJsonWithBearerToken Post json with bearer token 请求
|
||||
func PostJsonWithBearerToken(url string, parameter []byte, token string, timeout time.Duration) ([]byte, int, error) {
|
||||
bufParameter := bytes.NewBuffer(parameter)
|
||||
request, err := http.NewRequest("POST", url, bufParameter)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
request.Header.Set("Content-type", "application/json")
|
||||
request.Header.Set("Authorization", fmt.Sprintf("bearer %s", token))
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("请求网络错误: %v", err)
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(response.Body)
|
||||
|
||||
status := response.StatusCode
|
||||
resp, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, status, errors.New("网络请求结果读取失败")
|
||||
}
|
||||
|
||||
return resp, status, nil
|
||||
}
|
||||
|
||||
// func determiningEncoding(r *bufio.Reader) encoding.Encoding {
|
||||
// b, err := r.Peek(1024)
|
||||
// if err != nil && err != io.EOF {
|
||||
// log.Printf("Fetcher error: %v", err)
|
||||
// return unicode.UTF8
|
||||
// }
|
||||
|
||||
// e, _, _ := charset.DetermineEncoding(b, "")
|
||||
// return e
|
||||
// }
|
||||
20
internal/pkg/file/file.go
Normal file
20
internal/pkg/file/file.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package file
|
||||
|
||||
import "os"
|
||||
|
||||
func Exists(path string) (bool, error) {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
func Mkdir(path string) error {
|
||||
return os.MkdirAll(path, os.ModePerm)
|
||||
}
|
||||
159
internal/pkg/file/upload.go
Normal file
159
internal/pkg/file/upload.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
gonanoid "github.com/matoous/go-nanoid/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxImageSize = 10 << 20 // 10 MB
|
||||
MaxFileSize = 50 << 20 // 50 MB
|
||||
)
|
||||
|
||||
var ErrUnsupported = errors.New("文件格式不支持")
|
||||
|
||||
type FileType int
|
||||
|
||||
const (
|
||||
ALL FileType = 0
|
||||
IMG FileType = 1
|
||||
)
|
||||
|
||||
func UploadFilename(filepath string, t FileType) (string, error) {
|
||||
fileOpen, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer fileOpen.Close()
|
||||
|
||||
fileBytes, err := io.ReadAll(fileOpen)
|
||||
if err != nil {
|
||||
return "", errors.New("failed to read file")
|
||||
}
|
||||
|
||||
if t == IMG {
|
||||
// 判断是不是图片
|
||||
if !filetype.IsImage(fileBytes) {
|
||||
return "", ErrUnsupported
|
||||
}
|
||||
}
|
||||
|
||||
kind, err := filetype.Match(fileBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if kind == filetype.Unknown {
|
||||
return "", ErrUnsupported
|
||||
}
|
||||
|
||||
// 使用 filetype 判断类型后已经去读了一些bytes了
|
||||
// 要恢复文件读取位置
|
||||
_, err = fileOpen.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
dir := GetPath()
|
||||
exist, _ := Exists(dir)
|
||||
if !exist {
|
||||
if err := Mkdir(dir); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
filename := GenFilename(kind.Extension)
|
||||
path := path.Join(dir, filename)
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = io.Copy(f, fileOpen)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return "/" + path, nil
|
||||
}
|
||||
|
||||
func UploadFile(file *multipart.FileHeader, t FileType) (string, error) {
|
||||
if file.Size > MaxFileSize {
|
||||
return "", errors.New("failed to receive file too large")
|
||||
}
|
||||
|
||||
fileOpen, err := file.Open()
|
||||
if err != nil {
|
||||
return "", errors.New("fialed to open file")
|
||||
}
|
||||
defer fileOpen.Close()
|
||||
|
||||
fileBytes, err := io.ReadAll(fileOpen)
|
||||
if err != nil {
|
||||
return "", errors.New("failed to read file")
|
||||
}
|
||||
|
||||
if t == IMG {
|
||||
// 判断是不是图片
|
||||
if !filetype.IsImage(fileBytes) {
|
||||
return "", ErrUnsupported
|
||||
}
|
||||
}
|
||||
|
||||
kind, err := filetype.Match(fileBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if kind == filetype.Unknown {
|
||||
return "", ErrUnsupported
|
||||
}
|
||||
|
||||
// 使用 filetype 判断类型后已经去读了一些bytes了
|
||||
// 要恢复文件读取位置
|
||||
_, err = fileOpen.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
dir := GetPath()
|
||||
exist, _ := Exists(dir)
|
||||
if !exist {
|
||||
if err := Mkdir(dir); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
filename := GenFilename(kind.Extension)
|
||||
path := path.Join(dir, filename)
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = io.Copy(f, fileOpen)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return "/" + path, nil
|
||||
}
|
||||
|
||||
func GetPath() string {
|
||||
return fmt.Sprintf("upload/%s/%s/%s/", time.Now().Format("2006"), time.Now().Format("01"), time.Now().Format("02"))
|
||||
}
|
||||
|
||||
func GenFilename(ext string) string {
|
||||
id, _ := gonanoid.New()
|
||||
return fmt.Sprintf("%s.%s", id, ext)
|
||||
}
|
||||
26
internal/pkg/gin/gu/cors.go
Normal file
26
internal/pkg/gin/gu/cors.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package gu
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Cors() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
method := c.Request.Method
|
||||
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, AccessToken, X-CSRF-Token, Authorization, Token")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
// 放行所有OPTIONS方法
|
||||
if method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
}
|
||||
// 处理请求
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
40
internal/pkg/gin/gu/response.go
Normal file
40
internal/pkg/gin/gu/response.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package gu
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data"`
|
||||
}
|
||||
|
||||
type PageData struct {
|
||||
Total int64 `json:"total"`
|
||||
PageID int32 `json:"page_id"`
|
||||
PageSize int32 `json:"page_size"`
|
||||
Result any `json:"result"`
|
||||
}
|
||||
|
||||
func Ok(ctx *gin.Context, data any) {
|
||||
ResponseJson(ctx, http.StatusOK, "ok", data)
|
||||
}
|
||||
|
||||
func Failed(ctx *gin.Context, message string) {
|
||||
ResponseJson(ctx, http.StatusInternalServerError, message, nil)
|
||||
}
|
||||
|
||||
func FailedWithCode(ctx *gin.Context, code int, message string) {
|
||||
ResponseJson(ctx, code, message, nil)
|
||||
}
|
||||
|
||||
func ResponseJson(ctx *gin.Context, code int, message string, data any) {
|
||||
ctx.JSON(code, response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
57
internal/pkg/gin/gu/validator.go
Normal file
57
internal/pkg/gin/gu/validator.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package gu
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/go-playground/locales/en"
|
||||
"github.com/go-playground/locales/zh"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
"github.com/go-playground/validator/v10"
|
||||
enTranslations "github.com/go-playground/validator/v10/translations/en"
|
||||
chTranslations "github.com/go-playground/validator/v10/translations/zh"
|
||||
)
|
||||
|
||||
var trans ut.Translator
|
||||
|
||||
// loca 通常取决于 http 请求头的 'Accept-Language'
|
||||
func SetValidatorTrans(local string) (err error) {
|
||||
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
||||
zhT := zh.New() // chinese
|
||||
enT := en.New() // english
|
||||
uni := ut.New(enT, zhT, enT)
|
||||
|
||||
var o bool
|
||||
trans, o = uni.GetTranslator(local)
|
||||
if !o {
|
||||
return fmt.Errorf("uni.GetTranslator(%s) failed", local)
|
||||
}
|
||||
// register translate
|
||||
// 注册翻译器
|
||||
switch local {
|
||||
case "en":
|
||||
err = enTranslations.RegisterDefaultTranslations(v, trans)
|
||||
case "zh":
|
||||
err = chTranslations.RegisterDefaultTranslations(v, trans)
|
||||
default:
|
||||
err = enTranslations.RegisterDefaultTranslations(v, trans)
|
||||
}
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ValidatorErrors(ctx *gin.Context, err error) {
|
||||
if errors, ok := err.(validator.ValidationErrors); ok {
|
||||
errs := gin.H{}
|
||||
for _, e := range errors {
|
||||
errs[e.StructField()] = strings.Replace(e.Translate(trans), e.StructField(), "", -1)
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, errs)
|
||||
} else {
|
||||
ctx.JSON(http.StatusBadRequest, err)
|
||||
}
|
||||
}
|
||||
33
internal/pkg/logger/log.go
Normal file
33
internal/pkg/logger/log.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"management/internal/config"
|
||||
|
||||
"github.com/natefinch/lumberjack"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func Init() {
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
logRotate := &lumberjack.Logger{
|
||||
Filename: "./log/run.log", // 日志文件的位置
|
||||
MaxSize: 10, // 在进行切割之前,日志文件的最大大小(以MB为单位)
|
||||
MaxBackups: 100, // 保留旧文件的最大个数
|
||||
MaxAge: 30, // 保留旧文件的最大天数
|
||||
Compress: true,
|
||||
}
|
||||
zerolog.TimeFieldFormat = time.DateTime
|
||||
log.Logger = log.With().Caller().Logger()
|
||||
|
||||
if config.File.App.Prod {
|
||||
log.Logger = log.Output(logRotate)
|
||||
} else {
|
||||
consoleWriter := zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.DateTime}
|
||||
multi := zerolog.MultiLevelWriter(consoleWriter, logRotate)
|
||||
log.Logger = log.Output(multi)
|
||||
}
|
||||
}
|
||||
53
internal/pkg/rand/rand.go
Normal file
53
internal/pkg/rand/rand.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package rand
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
randv2 "math/rand/v2"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
func Bytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
nRead, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bytes: %w", err)
|
||||
}
|
||||
if nRead < n {
|
||||
return nil, fmt.Errorf("bytes: didn't read enough random bytes")
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func String(n int) (string, error) {
|
||||
b, err := Bytes(n)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("string: %w", err)
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func RandomInt(n int) string {
|
||||
letters := []byte("0123456789")
|
||||
l := len(letters)
|
||||
result := make([]byte, n)
|
||||
for i := range result {
|
||||
result[i] = letters[randv2.IntN(l)]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomString(n int) string {
|
||||
var sb strings.Builder
|
||||
k := len(alphabet)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
c := alphabet[randv2.IntN(k)]
|
||||
sb.WriteByte(c)
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
126
internal/pkg/redis/redis.go
Normal file
126
internal/pkg/redis/redis.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"management/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
engine *redis.Client
|
||||
ErrRedisKeyNotFound = errors.New("redis key not found")
|
||||
)
|
||||
|
||||
// func GetRedis() *redis.Client {
|
||||
// return rd
|
||||
// }
|
||||
|
||||
func Init() error {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", config.File.Redis.Host, config.File.Redis.Port),
|
||||
Password: config.File.Redis.Password,
|
||||
DB: config.File.Redis.DB,
|
||||
})
|
||||
_, err := rdb.Ping(context.Background()).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
engine = rdb
|
||||
return nil
|
||||
}
|
||||
|
||||
func Encode(a any) ([]byte, error) {
|
||||
var b bytes.Buffer
|
||||
if err := gob.NewEncoder(&b).Encode(a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// Set 设置值
|
||||
func Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
|
||||
return engine.Set(ctx, key, value, expiration).Err()
|
||||
}
|
||||
|
||||
// Del 删除键值
|
||||
func Del(ctx context.Context, keys ...string) error {
|
||||
return engine.Del(ctx, keys...).Err()
|
||||
}
|
||||
|
||||
// Get 获取值
|
||||
func Get(ctx context.Context, key string) (string, error) {
|
||||
val, err := engine.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return "", ErrRedisKeyNotFound
|
||||
} else if err != nil {
|
||||
return "", fmt.Errorf("cannot get value with:[%s]: %v", key, err)
|
||||
} else {
|
||||
return val, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetBytes 获取值
|
||||
func GetBytes(ctx context.Context, key string) ([]byte, error) {
|
||||
val, err := engine.Get(ctx, key).Bytes()
|
||||
if err == redis.Nil {
|
||||
return nil, ErrRedisKeyNotFound
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("cannot get value with:[%s]: %v", key, err)
|
||||
} else {
|
||||
return val, nil
|
||||
}
|
||||
}
|
||||
|
||||
func Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd {
|
||||
return engine.Scan(ctx, cursor, match, count)
|
||||
}
|
||||
|
||||
func Keys(ctx context.Context, pattern string) ([]string, error) {
|
||||
return engine.Keys(ctx, pattern).Result()
|
||||
}
|
||||
|
||||
func ListKeys(ctx context.Context, pattern string, pageID int, pageSize int) ([]string, int, error) {
|
||||
all, err := engine.Keys(ctx, pattern).Result()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
count := len(all)
|
||||
if count == 0 {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 使用SCAN命令分页获取键
|
||||
cursor := uint64(0)
|
||||
var keys []string
|
||||
for {
|
||||
var scanResult []string
|
||||
var err error
|
||||
scanResult, cursor, err = engine.Scan(ctx, cursor, pattern, int64(pageSize)).Result()
|
||||
if err != nil {
|
||||
return nil, count, err
|
||||
}
|
||||
keys = append(keys, scanResult...)
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
startIndex := (pageID - 1) * pageSize
|
||||
endIndex := startIndex + pageSize
|
||||
if startIndex >= len(keys) {
|
||||
return nil, count, nil
|
||||
}
|
||||
if endIndex > len(keys) {
|
||||
endIndex = len(keys)
|
||||
}
|
||||
return keys[startIndex:endIndex], count, nil
|
||||
}
|
||||
83
internal/pkg/session/redis.go
Normal file
83
internal/pkg/session/redis.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"management/internal/pkg/redis"
|
||||
)
|
||||
|
||||
var (
|
||||
storePrefix = "scs:session:"
|
||||
ctx = context.Background()
|
||||
DefaultRedisStore = newRedisStore()
|
||||
)
|
||||
|
||||
type redisStore struct{}
|
||||
|
||||
func newRedisStore() *redisStore {
|
||||
return &redisStore{}
|
||||
}
|
||||
|
||||
// Delete should remove the session token and corresponding data from the
|
||||
// session store. If the token does not exist then Delete should be a no-op
|
||||
// and return nil (not an error).
|
||||
func (s *redisStore) Delete(token string) error {
|
||||
return redis.Del(ctx, storePrefix+token)
|
||||
}
|
||||
|
||||
// Find should return the data for a session token from the store. If the
|
||||
// session token is not found or is expired, the found return value should
|
||||
// be false (and the err return value should be nil). Similarly, tampered
|
||||
// or malformed tokens should result in a found return value of false and a
|
||||
// nil err value. The err return value should be used for system errors only.
|
||||
func (s *redisStore) Find(token string) (b []byte, found bool, err error) {
|
||||
val, err := redis.GetBytes(ctx, storePrefix+token)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
} else {
|
||||
return val, true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Commit should add the session token and data to the store, with the given
|
||||
// expiry time. If the session token already exists, then the data and
|
||||
// expiry time should be overwritten.
|
||||
func (s *redisStore) Commit(token string, b []byte, expiry time.Time) error {
|
||||
// TODO: 这边可以调整时间
|
||||
exp, err := time.ParseInLocation(time.DateTime, time.Now().Format("2006-01-02")+" 23:59:59", time.Local)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
expired := exp.Sub(t)
|
||||
return redis.Set(ctx, storePrefix+token, b, expired)
|
||||
}
|
||||
|
||||
// All should return a map containing data for all active sessions (i.e.
|
||||
// sessions which have not expired). The map key should be the session
|
||||
// token and the map value should be the session data. If no active
|
||||
// sessions exist this should return an empty (not nil) map.
|
||||
func (s *redisStore) All() (map[string][]byte, error) {
|
||||
sessions := make(map[string][]byte)
|
||||
|
||||
iter := redis.Scan(ctx, 0, storePrefix+"*", 0).Iterator()
|
||||
for iter.Next(ctx) {
|
||||
key := iter.Val()
|
||||
token := key[len(storePrefix):]
|
||||
data, exists, err := s.Find(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if exists {
|
||||
sessions[token] = data
|
||||
}
|
||||
}
|
||||
if err := iter.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
59
internal/pkg/session/session.go
Normal file
59
internal/pkg/session/session.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"management/internal/config"
|
||||
db "management/internal/db/sqlc"
|
||||
|
||||
"github.com/alexedwards/scs/pgxstore"
|
||||
"github.com/alexedwards/scs/v2"
|
||||
)
|
||||
|
||||
var sessionManager *scs.SessionManager
|
||||
|
||||
func Init() {
|
||||
sessionManager = scs.New()
|
||||
sessionManager.Lifetime = 24 * time.Hour
|
||||
sessionManager.IdleTimeout = 2 * time.Hour
|
||||
sessionManager.Cookie.Name = "token"
|
||||
sessionManager.Cookie.HttpOnly = true
|
||||
sessionManager.Cookie.Persist = true
|
||||
sessionManager.Cookie.SameSite = http.SameSiteStrictMode
|
||||
sessionManager.Cookie.Secure = config.File.App.Prod
|
||||
|
||||
// postgres
|
||||
// github.com/alexedwards/scs/postgresstore
|
||||
// sessionManager.Store = postgresstore.New(db)
|
||||
// pgx
|
||||
// github.com/alexedwards/scs/pgxstore
|
||||
sessionManager.Store = pgxstore.New(db.Engine.Pool())
|
||||
// redis
|
||||
// sessionManager.Store = newRedisStore()
|
||||
}
|
||||
|
||||
func Destroy(ctx context.Context) error {
|
||||
return sessionManager.Destroy(ctx)
|
||||
}
|
||||
|
||||
func LoadAndSave(next http.Handler) http.Handler {
|
||||
return sessionManager.LoadAndSave(next)
|
||||
}
|
||||
|
||||
func Put(ctx context.Context, key string, val interface{}) {
|
||||
sessionManager.Put(ctx, key, val)
|
||||
}
|
||||
|
||||
func GetBytes(ctx context.Context, key string) []byte {
|
||||
return sessionManager.GetBytes(ctx, key)
|
||||
}
|
||||
|
||||
func Exists(ctx context.Context, key string) bool {
|
||||
return sessionManager.Exists(ctx, key)
|
||||
}
|
||||
|
||||
func RenewToken(ctx context.Context) error {
|
||||
return sessionManager.RenewToken(ctx)
|
||||
}
|
||||
41
internal/pkg/smb/smb.go
Normal file
41
internal/pkg/smb/smb.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package smb
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net"
|
||||
|
||||
"management/internal/config"
|
||||
|
||||
"github.com/hirochachacha/go-smb2"
|
||||
)
|
||||
|
||||
var FS fs.FS
|
||||
|
||||
func Init() error {
|
||||
conn, err := net.Dial("tcp", config.File.Smb.Host+":445")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
d := &smb2.Dialer{
|
||||
Initiator: &smb2.NTLMInitiator{
|
||||
User: config.File.Smb.Name,
|
||||
Password: config.File.Smb.Pass,
|
||||
},
|
||||
}
|
||||
s, err := d.Dial(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Logoff()
|
||||
|
||||
fs, err := s.Mount(config.File.Smb.Mount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fs.Umount()
|
||||
|
||||
FS = fs.DirFS(".")
|
||||
return nil
|
||||
}
|
||||
20
internal/pkg/snowflake/snowflake.go
Normal file
20
internal/pkg/snowflake/snowflake.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package snowflake
|
||||
|
||||
import (
|
||||
"github.com/bwmarrin/snowflake"
|
||||
)
|
||||
|
||||
var node *snowflake.Node
|
||||
|
||||
func Init() error {
|
||||
n, err := snowflake.NewNode(1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
node = n
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetId() int64 {
|
||||
return node.Generate().Int64()
|
||||
}
|
||||
25
internal/pkg/sqids/sqids.go
Normal file
25
internal/pkg/sqids/sqids.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package sqids
|
||||
|
||||
import "github.com/sqids/sqids-go"
|
||||
|
||||
var engine *sqids.Sqids
|
||||
|
||||
func Init() error {
|
||||
var err error
|
||||
engine, err = sqids.New(sqids.Options{
|
||||
MinLength: 12,
|
||||
Alphabet: "AvjM1lkB8N6cuhs2oFxnXyYDwCmLGI7JOzt9fW3HRgb5ZQrqaU04TePSVKdpiE",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Encode(ids []uint64) (string, error) {
|
||||
return engine.Encode(ids)
|
||||
}
|
||||
|
||||
func Decode(s string) []uint64 {
|
||||
return engine.Decode(s)
|
||||
}
|
||||
40
internal/pkg/strutil/internal.go
Normal file
40
internal/pkg/strutil/internal.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package strutil
|
||||
|
||||
import "strings"
|
||||
|
||||
// splitWordsToLower 将一个字符串按大写字母分割成若干个字符串
|
||||
func splitWordsToLower(s string) []string {
|
||||
var res []string
|
||||
|
||||
upperIndexes := upperIndex(s)
|
||||
l := len(upperIndexes)
|
||||
if upperIndexes == nil || l == 0 {
|
||||
if s != "" {
|
||||
res = append(res, s)
|
||||
}
|
||||
return res
|
||||
}
|
||||
for i := 0; i < l; i++ {
|
||||
if i < l-1 {
|
||||
res = append(res, strings.ToLower(s[upperIndexes[i]:upperIndexes[i+1]]))
|
||||
} else {
|
||||
res = append(res, strings.ToLower(s[upperIndexes[i]:]))
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// upperIndex 获得一个int slice,其元素是一个字符串的所有大写字母索引
|
||||
func upperIndex(s string) []int {
|
||||
var res []int
|
||||
for i := 0; i < len(s); i++ {
|
||||
if 64 < s[i] && s[i] < 91 {
|
||||
res = append(res, i)
|
||||
}
|
||||
}
|
||||
if len(s) > 0 && res != nil && res[0] != 0 {
|
||||
res = append([]int{0}, res...)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
296
internal/pkg/strutil/strutil.go
Normal file
296
internal/pkg/strutil/strutil.go
Normal file
@@ -0,0 +1,296 @@
|
||||
// Package strutil 实现了一些函数来操作字符串
|
||||
package strutil
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// CamelCase 转换字符串到驼峰法(CamelCase)
|
||||
func CamelCase(s string) string {
|
||||
if len(s) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
res := ""
|
||||
blankSpace := " "
|
||||
regex, _ := regexp.Compile("[-_&]+")
|
||||
ss := regex.ReplaceAllString(s, blankSpace)
|
||||
for i, v := range strings.Split(ss, blankSpace) {
|
||||
vv := []rune(v)
|
||||
if i == 0 {
|
||||
if vv[i] >= 65 && vv[i] <= 96 {
|
||||
vv[0] += 32
|
||||
}
|
||||
res += string(vv)
|
||||
} else {
|
||||
res += Capitalize(v)
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// Capitalize 将一个字符串的第一个字符转换为大写,其余的转换为小写
|
||||
func Capitalize(s string) string {
|
||||
if len(s) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
out := make([]rune, len(s))
|
||||
for i, v := range s {
|
||||
if i == 0 {
|
||||
out[i] = unicode.ToUpper(v)
|
||||
} else {
|
||||
out[i] = unicode.ToLower(v)
|
||||
}
|
||||
}
|
||||
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// UpperFirst 将字符串的第一个字符转换为大写
|
||||
func UpperFirst(s string) string {
|
||||
if len(s) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
r, size := utf8.DecodeRuneInString(s)
|
||||
r = unicode.ToUpper(r)
|
||||
|
||||
return string(r) + s[size:]
|
||||
}
|
||||
|
||||
// LowerFirst 将字符串的第一个字符转换为小写
|
||||
func LowerFirst(s string) string {
|
||||
if len(s) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
r, size := utf8.DecodeRuneInString(s)
|
||||
r = unicode.ToLower(r)
|
||||
|
||||
return string(r) + s[size:]
|
||||
}
|
||||
|
||||
// PadEnd 如果字符串比尺寸短,则将其垫在右侧
|
||||
// 填充字符如果超过大小,将被截断
|
||||
func PadEnd(source string, size int, padStr string) string {
|
||||
len1 := len(source)
|
||||
len2 := len(padStr)
|
||||
|
||||
if len1 >= size {
|
||||
return source
|
||||
}
|
||||
|
||||
fill := ""
|
||||
if len2 >= size-len1 {
|
||||
fill = padStr[0 : size-len1]
|
||||
} else {
|
||||
fill = strings.Repeat(padStr, size-len1)
|
||||
}
|
||||
return source + fill[0:size-len1]
|
||||
}
|
||||
|
||||
// PadStart 如果字符串比尺寸短,则将其垫在左侧
|
||||
// 填充字符如果超过大小,将被截断
|
||||
func PadStart(source string, size int, padStr string) string {
|
||||
len1 := len(source)
|
||||
len2 := len(padStr)
|
||||
|
||||
if len1 >= size {
|
||||
return source
|
||||
}
|
||||
|
||||
fill := ""
|
||||
if len2 >= size-len1 {
|
||||
fill = padStr[0 : size-len1]
|
||||
} else {
|
||||
fill = strings.Repeat(padStr, size-len1)
|
||||
}
|
||||
return fill[0:size-len1] + source
|
||||
}
|
||||
|
||||
// KebabCase 将字符串转为短横线隔开式(kebab-case)
|
||||
func KebabCase(s string) string {
|
||||
if len(s) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
regex := regexp.MustCompile(`[\W|_]+`)
|
||||
blankSpace := " "
|
||||
match := regex.ReplaceAllString(s, blankSpace)
|
||||
rs := strings.Split(match, blankSpace)
|
||||
|
||||
var res []string
|
||||
for _, v := range rs {
|
||||
splitWords := splitWordsToLower(v)
|
||||
if len(splitWords) > 0 {
|
||||
res = append(res, splitWords...)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(res, "-")
|
||||
}
|
||||
|
||||
// SnakeCase 将字符串转为蛇形命名(snake_case)
|
||||
func SnakeCase(s string) string {
|
||||
if len(s) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
regex := regexp.MustCompile(`[\W|_]+`)
|
||||
blankSpace := " "
|
||||
match := regex.ReplaceAllString(s, blankSpace)
|
||||
rs := strings.Split(match, blankSpace)
|
||||
|
||||
var res []string
|
||||
for _, v := range rs {
|
||||
splitWords := splitWordsToLower(v)
|
||||
if len(splitWords) > 0 {
|
||||
res = append(res, splitWords...)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(res, "_")
|
||||
}
|
||||
|
||||
// Before 在字符首次出现的位置之前,在源字符串中创建子串
|
||||
func Before(s, char string) string {
|
||||
if s == "" || char == "" {
|
||||
return s
|
||||
}
|
||||
i := strings.Index(s, char)
|
||||
return s[0:i]
|
||||
}
|
||||
|
||||
// BeforeLast 在字符最后出现的位置之前,在源字符串中创建子串
|
||||
func BeforeLast(s, char string) string {
|
||||
if s == "" || char == "" {
|
||||
return s
|
||||
}
|
||||
i := strings.LastIndex(s, char)
|
||||
return s[0:i]
|
||||
}
|
||||
|
||||
// After 在字符首次出现的位置后,在源字符串中创建子串
|
||||
func After(s, char string) string {
|
||||
if s == "" || char == "" {
|
||||
return s
|
||||
}
|
||||
i := strings.Index(s, char)
|
||||
return s[i+len(char):]
|
||||
}
|
||||
|
||||
// AfterLast 在字符最后出现的位置后,在源字符串中创建子串
|
||||
func AfterLast(s, char string) string {
|
||||
if s == "" || char == "" {
|
||||
return s
|
||||
}
|
||||
i := strings.LastIndex(s, char)
|
||||
return s[i+len(char):]
|
||||
}
|
||||
|
||||
// IsString 检查值的数据类型是否为字符串
|
||||
func IsString(v any) bool {
|
||||
if v == nil {
|
||||
return false
|
||||
}
|
||||
switch v.(type) {
|
||||
case string:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ReverseStr 返回字符顺序与给定字符串相反的字符串
|
||||
func ReverseStr(s string) string {
|
||||
r := []rune(s)
|
||||
for i, j := 0, len(r)-1; i < j; i, j = i+1, j-1 {
|
||||
r[i], r[j] = r[j], r[i]
|
||||
}
|
||||
return string(r)
|
||||
}
|
||||
|
||||
// Wrap 用另一个字符串包住一个字符串
|
||||
func Wrap(str string, wrapWith string) string {
|
||||
if str == "" || wrapWith == "" {
|
||||
return str
|
||||
}
|
||||
var sb strings.Builder
|
||||
sb.WriteString(wrapWith)
|
||||
sb.WriteString(str)
|
||||
sb.WriteString(wrapWith)
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// Unwrap 从另一个字符串中解开一个给定的字符串,将改变str值
|
||||
func Unwrap(str string, wrapToken string) string {
|
||||
if str == "" || wrapToken == "" {
|
||||
return str
|
||||
}
|
||||
|
||||
firstIndex := strings.Index(str, wrapToken)
|
||||
lastIndex := strings.LastIndex(str, wrapToken)
|
||||
|
||||
if firstIndex == 0 && lastIndex > 0 && lastIndex <= len(str)-1 {
|
||||
if len(wrapToken) <= lastIndex {
|
||||
str = str[len(wrapToken):lastIndex]
|
||||
}
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// RemoveHTML 去除字符串中的 html, js
|
||||
func RemoveHTML(str string) string {
|
||||
if len(str) > 0 {
|
||||
// 删除脚本
|
||||
reg := regexp.MustCompile(`([\r\n])[\s]+`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`<script[^>]*?>.*?</script>`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
// 删除HTML
|
||||
reg = regexp.MustCompile(`<(.[^>]*)>`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`([\r\n])[\s]+`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`-->`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`<!--.*`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`&(quot|#34);`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`&(amp|#38);`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`&(lt|#60);`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`&(gt|#62);`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`&(nbsp|#160);`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`&(iexcl|#161);`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`&(cent|#162);`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`&(pound|#163);`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`&(copy|#169);`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
reg = regexp.MustCompile(`&#(\d+);`)
|
||||
str = reg.ReplaceAllString(str, "")
|
||||
|
||||
str = strings.ReplaceAll(str, "<", "")
|
||||
str = strings.ReplaceAll(str, ">", "")
|
||||
str = strings.ReplaceAll(str, "\n", "")
|
||||
str = strings.ReplaceAll(str, " ", "")
|
||||
str = strings.ReplaceAll(str, " ", "")
|
||||
|
||||
return str
|
||||
}
|
||||
return ""
|
||||
}
|
||||
59
internal/pkg/token/jwt_maker.go
Normal file
59
internal/pkg/token/jwt_maker.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
tk "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
const minSecretKeySize = 32
|
||||
|
||||
// JWTMaker JSON Web Token
|
||||
type JWTMaker struct {
|
||||
secretKey string
|
||||
}
|
||||
|
||||
// NewJWTMaker 创建一个新的JWTMaker
|
||||
func NewJWTMaker(secretKey string) error {
|
||||
if len(secretKey) < minSecretKeySize {
|
||||
return fmt.Errorf("invalid key size: must be at least %d characters", minSecretKeySize)
|
||||
}
|
||||
engine = &JWTMaker{secretKey}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateToken 根据用户名和时间创建一个新的token
|
||||
func (maker *JWTMaker) CreateToken(id string, username string, duration time.Duration) (string, *Payload, error) {
|
||||
payload, err := NewPayload(id, username, duration)
|
||||
if err != nil {
|
||||
return "", payload, err
|
||||
}
|
||||
|
||||
jwtToken := tk.NewWithClaims(tk.SigningMethodHS256, payload)
|
||||
token, err := jwtToken.SignedString([]byte(maker.secretKey))
|
||||
return token, payload, err
|
||||
}
|
||||
|
||||
// VerifyToken checks if the token is valid or not
|
||||
func (maker *JWTMaker) VerifyToken(token string) (*Payload, error) {
|
||||
keyFunc := func(token *tk.Token) (interface{}, error) {
|
||||
_, ok := token.Method.(*tk.SigningMethodHMAC)
|
||||
if !ok {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
return []byte(maker.secretKey), nil
|
||||
}
|
||||
|
||||
jwtToken, err := tk.ParseWithClaims(token, &Payload{}, keyFunc)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
payload, ok := jwtToken.Claims.(*Payload)
|
||||
if !ok {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
24
internal/pkg/token/maker.go
Normal file
24
internal/pkg/token/maker.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var engine Maker
|
||||
|
||||
// Maker 管理token的接口定义
|
||||
type Maker interface {
|
||||
// CreateToken 根据用户名和时间创建一个新的token
|
||||
CreateToken(id string, username string, duration time.Duration) (string, *Payload, error)
|
||||
|
||||
// VerifyToken 校验token是否正确
|
||||
VerifyToken(token string) (*Payload, error)
|
||||
}
|
||||
|
||||
func CreateToken(id string, username string, duration time.Duration) (string, *Payload, error) {
|
||||
return engine.CreateToken(id, username, duration)
|
||||
}
|
||||
|
||||
func VerifyToken(token string) (*Payload, error) {
|
||||
return engine.VerifyToken(token)
|
||||
}
|
||||
60
internal/pkg/token/paseto_maker.go
Normal file
60
internal/pkg/token/paseto_maker.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"management/internal/config"
|
||||
|
||||
"github.com/aead/chacha20poly1305"
|
||||
"github.com/o1egl/paseto"
|
||||
)
|
||||
|
||||
// PasetoMaker is a PASETO token maker
|
||||
type PasetoMaker struct {
|
||||
paseto *paseto.V2
|
||||
symmetricKey []byte
|
||||
}
|
||||
|
||||
// NewPasetoMaker creates a new PasetoMaker
|
||||
func NewPasetoMaker() error {
|
||||
symmetricKey := config.File.JWT.SigningKey
|
||||
if len(symmetricKey) != chacha20poly1305.KeySize {
|
||||
return fmt.Errorf("invalid key size: must be exactly %d characters", chacha20poly1305.KeySize)
|
||||
}
|
||||
|
||||
engine = &PasetoMaker{
|
||||
paseto: paseto.NewV2(),
|
||||
symmetricKey: []byte(symmetricKey),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateToken creates a new token for a specific username and duration
|
||||
func (maker *PasetoMaker) CreateToken(id string, username string, duration time.Duration) (string, *Payload, error) {
|
||||
payload, err := NewPayload(id, username, duration)
|
||||
if err != nil {
|
||||
return "", payload, err
|
||||
}
|
||||
|
||||
token, err := maker.paseto.Encrypt(maker.symmetricKey, payload, nil)
|
||||
return token, payload, err
|
||||
}
|
||||
|
||||
// VerifyToken checks if the token is valid or not
|
||||
func (maker *PasetoMaker) VerifyToken(token string) (*Payload, error) {
|
||||
payload := &Payload{}
|
||||
|
||||
err := maker.paseto.Decrypt(token, maker.symmetricKey, payload, nil)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
err = payload.Valid()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
43
internal/pkg/token/payload.go
Normal file
43
internal/pkg/token/payload.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
tk "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// Different types of error returned by the VerifyToken function
|
||||
var (
|
||||
ErrInvalidToken = errors.New("token is invalid")
|
||||
ErrExpiredToken = errors.New("token has expired")
|
||||
)
|
||||
|
||||
// Payload contains the payload data of the token
|
||||
type Payload struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
tk.RegisteredClaims // v5版本新加的方法
|
||||
}
|
||||
|
||||
// NewPayload creates a new token payload with a specific username and duration
|
||||
func NewPayload(id string, username string, duration time.Duration) (*Payload, error) {
|
||||
payload := &Payload{
|
||||
id,
|
||||
username,
|
||||
tk.RegisteredClaims{
|
||||
ExpiresAt: tk.NewNumericDate(time.Now().Add(duration)), // 过期时间24小时
|
||||
IssuedAt: tk.NewNumericDate(time.Now()), // 签发时间
|
||||
NotBefore: tk.NewNumericDate(time.Now()), // 生效时间
|
||||
},
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// Valid checks if the token payload is valid or not
|
||||
func (payload *Payload) Valid() error {
|
||||
if time.Now().After(payload.ExpiresAt.Time) {
|
||||
return ErrExpiredToken
|
||||
}
|
||||
return nil
|
||||
}
|
||||
78
internal/pkg/validation/validation.go
Normal file
78
internal/pkg/validation/validation.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// 使用预编译的全局正则表达式,避免重复创建和编译.
|
||||
var (
|
||||
lengthRegex = regexp.MustCompile(`^.{3,20}$`) // 长度在 3 到 20 个字符之间
|
||||
validRegex = regexp.MustCompile(`^[A-Za-z0-9_]+$`) // 仅包含字母、数字和下划线
|
||||
letterRegex = regexp.MustCompile(`[A-Za-z]`) // 至少包含一个字母
|
||||
numberRegex = regexp.MustCompile(`\d`) // 至少包含一个数字
|
||||
emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) // 邮箱格式
|
||||
phoneRegex = regexp.MustCompile(`^1[3-9]\d{9}$`) // 中国手机号
|
||||
)
|
||||
|
||||
// IsValidUsername 校验用户名是否合法.
|
||||
func IsValidUsername(username string) bool {
|
||||
// 校验长度
|
||||
if !lengthRegex.MatchString(username) {
|
||||
return false
|
||||
}
|
||||
// 校验字符合法性
|
||||
if !validRegex.MatchString(username) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsValidPassword 判断密码是否符合复杂度要求.
|
||||
func IsValidPassword(password string) error {
|
||||
switch {
|
||||
// 检查新密码是否为空
|
||||
case password == "":
|
||||
return errors.New("password cannot be empty")
|
||||
// 检查新密码的长度要求
|
||||
case len(password) < 6:
|
||||
return errors.New("password must be at least 6 characters long")
|
||||
// 使用正则表达式检查是否至少包含一个字母
|
||||
case !letterRegex.MatchString(password):
|
||||
return errors.New("password must contain at least one letter")
|
||||
// 使用正则表达式检查是否至少包含一个数字
|
||||
case !numberRegex.MatchString(password):
|
||||
return errors.New("password must contain at least one number")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValidEmail 判断电子邮件是否合法.
|
||||
func IsValidEmail(email string) error {
|
||||
// 检查电子邮件地址格式
|
||||
if email == "" {
|
||||
return errors.New("email cannot be empty")
|
||||
}
|
||||
|
||||
// 使用正则表达式校验电子邮件格式
|
||||
if !emailRegex.MatchString(email) {
|
||||
return errors.New("invalid email format")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValidPhone 判断手机号码是否合法.
|
||||
func IsValidPhone(phone string) error {
|
||||
// 检查手机号码格式
|
||||
if phone == "" {
|
||||
return errors.New("phone cannot be empty")
|
||||
}
|
||||
|
||||
// 使用正则表达式校验手机号码格式(假设是中国手机号,11位数字)
|
||||
if !phoneRegex.MatchString(phone) {
|
||||
return errors.New("invalid phone format")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
86
internal/pkg/validation/validator.go
Normal file
86
internal/pkg/validation/validator.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// 初始化一个验证器实例
|
||||
var validate = validator.New()
|
||||
|
||||
// 自定义验证规则:requiredint
|
||||
func init() {
|
||||
validate.RegisterValidation("telephone", func(fl validator.FieldLevel) bool {
|
||||
if err := IsValidPhone(fl.Field().String()); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
validate.RegisterValidation("dateonly", func(fl validator.FieldLevel) bool {
|
||||
_, err := time.ParseInLocation("2006-01-02", fl.Field().String(), time.Local)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func ValidateForm(s any) error {
|
||||
// 验证结构体数据
|
||||
if err := validate.Struct(s); err != nil {
|
||||
if _, ok := err.(*validator.InvalidValidationError); ok {
|
||||
return errors.New("验证器配置错误")
|
||||
}
|
||||
|
||||
// 获取结构体的反射类型
|
||||
t := reflect.TypeOf(s)
|
||||
|
||||
errorMessages := make([]string, 0)
|
||||
for _, err := range err.(validator.ValidationErrors) {
|
||||
// 获取字段名
|
||||
fieldName := err.Field()
|
||||
// 获取标签名
|
||||
tag := err.Tag()
|
||||
// 获取验证失败的字段值
|
||||
// value := err.Value()
|
||||
|
||||
// 通过反射获取字段
|
||||
field, found := t.Elem().FieldByName(fieldName)
|
||||
errorMsg := fmt.Sprintf("[%s] 验证失败: %s", fieldName, translate(tag))
|
||||
if found {
|
||||
commentTag := field.Tag.Get("comment")
|
||||
errorMsg = fmt.Sprintf("[%s] 验证失败: %s", commentTag, translate(tag))
|
||||
}
|
||||
|
||||
errorMessages = append(errorMessages, errorMsg)
|
||||
}
|
||||
return errors.New(strings.Join(errorMessages, "; "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func translate(s string) string {
|
||||
switch s {
|
||||
case "required":
|
||||
return "不能为空"
|
||||
case "min":
|
||||
return "不能小于"
|
||||
case "max":
|
||||
return "不能大于"
|
||||
case "email":
|
||||
return "不是有效的邮箱地址"
|
||||
case "telephone":
|
||||
return "不是有效的手机号"
|
||||
case "datetime":
|
||||
return "不是有效的日期时间"
|
||||
case "dateonly":
|
||||
return "不是有效的日期"
|
||||
}
|
||||
return s
|
||||
}
|
||||
Reference in New Issue
Block a user