gorm wire

This commit is contained in:
2025-05-07 14:12:53 +08:00
parent 461531c308
commit 68606c76f9
111 changed files with 1726 additions and 5809 deletions

View File

@@ -1,7 +0,0 @@
package config
type App struct {
Host string `mapstructure:"host" json:"host" yaml:"host"` // 服务地址
Port int `mapstructure:"port" json:"port" yaml:"port"` // 服务端口
Prod bool `mapstructure:"prod" json:"prod" yaml:"prod"` // 是否正式
}

View File

@@ -1,6 +0,0 @@
package config
type Applet struct {
AppID string `mapstructure:"app_id" json:"app_id" yaml:"app_id"` // appid
AppSecret string `mapstructure:"app_secret" json:"app_secret" yaml:"app_secret"` // secret
}

View File

@@ -1,9 +0,0 @@
package config
type Captcha struct {
OpenCaptcha int `mapstructure:"open_captcha" json:"open_captcha" yaml:"open_captcha"` // 是否开启防爆次数
OpenCaptchaTimeout string `mapstructure:"open_captcha_timeout" json:"open_captcha_timeout" yaml:"open_captcha_timeout"` // 缓存超时时间
ImgWidth int `mapstructure:"img_width" json:"img_width" yaml:"img_width"` // 验证码图片宽度
ImgHeight int `mapstructure:"img_height" json:"img_height" yaml:"img_height"` // 验证码图片高度
KeyLong int `mapstructure:"key_long" json:"key_long" yaml:"key_long"` // 验证码长度
}

View File

@@ -3,91 +3,73 @@ package config
import (
"fmt"
"path/filepath"
"strings"
"time"
"github.com/fsnotify/fsnotify"
"github.com/spf13/viper"
)
var File *Config
func New(path string) (*Config, error) {
v := viper.New()
v.AddConfigPath(filepath.Dir(path))
v.SetConfigName(filepath.Base(path))
v.SetConfigType(strings.TrimPrefix(filepath.Ext(path), "."))
if err := v.ReadInConfig(); err != nil {
return nil, err
}
const ConfigDefaultFile = "config.dev.yaml"
v.WatchConfig()
var config Config
v.OnConfigChange(func(e fsnotify.Event) {
fmt.Println("config file changed:", e.Name)
if err := v.Unmarshal(&config); err != nil {
fmt.Println(err)
}
})
if err := v.Unmarshal(&config); err != nil {
return nil, err
}
return &config, nil
}
type Config struct {
App App `mapstructure:"app" json:"app" yaml:"app"`
DB DB `mapstructure:"db" json:"db" yaml:"db"`
Redis Redis `mapstructure:"redis" json:"redis" yaml:"redis"`
Cors Cors `mapstructure:"cors" json:"cors" yaml:"cors"`
JWT JWT `mapstructure:"jwt" json:"jwt" yaml:"jwt"`
AliyunUpload AliyunUpload `mapstructure:"aliyunupload" json:"aliyunupload" yaml:"aliyunupload"`
TencentUpload TencentUpload `mapstructure:"tencentupload" json:"tencentupload" yaml:"tencentupload"`
Captcha Captcha `mapstructure:"captcha" json:"captcha"`
Applet Applet `mapstructure:"applet" json:"applet" yaml:"applet"`
Smb Smb `mapstructure:"smb" json:"smb" yaml:"smb"`
}
func New(path string) (*Config, error) {
fp := "."
fn := ConfigDefaultFile
if len(path) > 0 {
fp, fn = filepath.Split(path)
if len(fp) == 0 {
fp = "."
}
}
v := viper.New()
v.AddConfigPath(fp)
v.SetConfigName(fn)
v.SetConfigType("yaml")
if err := v.ReadInConfig(); err != nil {
return nil, err
}
v.WatchConfig()
var conf *Config
v.OnConfigChange(func(e fsnotify.Event) {
fmt.Println("config file changed:", e.Name)
if err := v.Unmarshal(&conf); err != nil {
fmt.Println(err)
}
})
if err := v.Unmarshal(&conf); err != nil {
return nil, err
}
return conf, nil
}
func Init(path string) error {
fp := "."
fn := ConfigDefaultFile
if len(path) > 0 {
fp, fn = filepath.Split(path)
if len(fp) == 0 {
fp = "."
}
}
v := viper.New()
v.AddConfigPath(fp)
v.SetConfigName(fn)
v.SetConfigType("yaml")
if err := v.ReadInConfig(); err != nil {
return err
}
v.WatchConfig()
v.OnConfigChange(func(e fsnotify.Event) {
fmt.Println("config file changed:", e.Name)
if err := v.Unmarshal(&File); err != nil {
fmt.Println(err)
}
})
if err := v.Unmarshal(&File); err != nil {
return err
}
return nil
App struct {
Host string `mapstructure:"host"` // 服务地址
Port int `mapstructure:"port"` // 服务端口
Prod bool `mapstructure:"prod"` // 是否正式
} `mapstructure:"app"`
DB struct {
Driver string `mapstructure:"driver"` // 数据库类型
Host string `mapstructure:"host"` // 数据库地址
Port int `mapstructure:"port"` // 数据库端口
Username string `mapstructure:"username"` // 数据库用户
Password string `mapstructure:"password"` // 数据库密码
DBName string `mapstructure:"db_name"` // 数据库名称
MaxOpenConns int `mapstructure:"max_open_conns"` // 数据库名称
MaxIdleConns int `mapstructure:"max_idle_conns"` // 数据库名称
} `mapstructure:"db"`
Redis struct {
Host string `mapstructure:"host"` // redis地址
Port int `mapstructure:"port"` // redis端口
Password string `mapstructure:"password"` // redis密码
DB int `mapstructure:"db"` // redis数据库
} `mapstructure:"redis"`
Cors struct {
Host string `mapstructure:"host"`
} `mapstructure:"cors"`
JWT struct {
SigningKey string `mapstructure:"signing_key"` // jwt签名
ExpiresTime time.Duration `mapstructure:"expires_time"` // 过期时间
RefreshTime time.Duration `mapstructure:"refresh_time"` // 刷新过期时间
Issuer string `mapstructure:"issuer"` // 签发者
} `mapstructure:"jwt"`
Captcha struct {
OpenCaptcha int `mapstructure:"open_captcha"` // 是否开启防爆次数
ImgWidth int `mapstructure:"img_width"` // 验证码图片宽度
ImgHeight int `mapstructure:"img_height"` // 验证码图片高度
KeyLong int `mapstructure:"key_long"` // 验证码长度
} `mapstructure:"captcha"`
}

View File

@@ -1,5 +0,0 @@
package config
type Cors struct {
Host string `mapstructure:"host" json:"host" yaml:"host"`
}

View File

@@ -1,12 +0,0 @@
package config
type DB struct {
Driver string `mapstructure:"driver" json:"driver" yaml:"driver"` // 数据库类型
Host string `mapstructure:"host" json:"host" yaml:"host"` // 数据库地址
Port int `mapstructure:"port" json:"port" yaml:"port"` // 数据库端口
Username string `mapstructure:"username" json:"username" yaml:"username"` // 数据库用户
Password string `mapstructure:"password" json:"password" yaml:"password"` // 数据库密码
DBName string `mapstructure:"db_name" json:"db_name" yaml:"db_name"` // 数据库名称
MaxOpenConns int `mapstructure:"max_open_conns" json:"max_open_conns" yaml:"max_open_conns"` // 数据库名称
MaxIdleConns int `mapstructure:"max_idle_conns" json:"max_idle_conns" yaml:"max_idle_conns"` // 数据库名称
}

View File

@@ -1,10 +0,0 @@
package config
import "time"
type JWT struct {
SigningKey string `mapstructure:"signing_key" json:"signing_key" yaml:"signing_key"` // jwt签名
ExpiresTime time.Duration `mapstructure:"expires_time" json:"expires_time" yaml:"expires_time"` // 过期时间
RefreshTime time.Duration `mapstructure:"refresh_time" json:"refresh_time" yaml:"refresh_time"` // 刷新过期时间
Issuer string `mapstructure:"issuer" json:"issuer" yaml:"issuer"` // 签发者
}

View File

@@ -1,8 +0,0 @@
package config
type Redis struct {
Host string `mapstructure:"host" json:"host" yaml:"host"` // redis地址
Port int `mapstructure:"port" json:"port" yaml:"port"` // redis端口
Password string `mapstructure:"password" json:"password" yaml:"password"` // redis密码
DB int `mapstructure:"db" json:"db_name" yaml:"db"` // redis数据库
}

View File

@@ -1,8 +0,0 @@
package config
type Smb struct {
Host string `mapstructure:"host" json:"host" yaml:"host"`
Name string `mapstructure:"name" json:"name" yaml:"name"`
Pass string `mapstructure:"pass" json:"pass" yaml:"pass"`
Mount string `mapstructure:"mount" json:"mount" yaml:"mount"`
}

View File

@@ -1,18 +0,0 @@
package config
type AliyunUpload struct {
Endpoint string `mapstructure:"endpoint" json:"endpoint" yaml:"endpoint"`
Bucket string `mapstructure:"bucket" json:"bucket" yaml:"bucket"`
AccessKeyID string `mapstructure:"access_key_id" json:"access_key_id" yaml:"access_key_id"`
AccessKeySecret string `mapstructure:"access_key_secret" json:"access_key_secret" yaml:"access_key_secret"`
}
type TencentUpload struct {
Region string `mapstructure:"region" json:"region" yaml:"region"`
Bucket string `mapstructure:"bucket" json:"bucket" yaml:"bucket"`
AccessKeyID string `mapstructure:"access_key_id" json:"access_key_id" yaml:"access_key_id"`
AccessKeySecret string `mapstructure:"access_key_secret" json:"access_key_secret" yaml:"access_key_secret"`
AllowImageMaxSize int64 `mapstructure:"allow_image_max_size" json:"allow_image_max_size" yaml:"allow_image_max_size"`
AllowImageExtension string `mapstructure:"allow_image_extension" json:"allow_image_extension" yaml:"allow_image_extension"`
AllowFileMaxSize int64 `mapstructure:"allow_file_max_size" json:"allow_file_max_size" yaml:"allow_file_max_size"`
}

View File

@@ -1,233 +0,0 @@
package fetcher
import (
"bytes"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"time"
)
// Get get请求
func Get(url string, timeout time.Duration) ([]byte, int, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, 0, fmt.Errorf("[get] new request err: %v", err)
}
client := &http.Client{Timeout: timeout}
resp, err := client.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("[get] client do err: %v", err)
}
defer resp.Body.Close()
status := resp.StatusCode
res, err := io.ReadAll(resp.Body)
if err != nil {
return nil, status, fmt.Errorf("read response err: %v", err)
}
return res, status, nil
}
// GetString 请求
func GetString(url string, parameter map[string]string, timeout time.Duration) ([]byte, int, error) {
byteParameter := new(bytes.Buffer)
w := multipart.NewWriter(byteParameter)
for k, v := range parameter {
w.WriteField(k, v)
}
w.Close()
request, err := http.NewRequest("GET", url, byteParameter)
if err != nil {
return nil, 0, fmt.Errorf("[getstring] new request err: %v", err)
}
request.Header.Set("Content-Type", w.FormDataContentType())
client := &http.Client{Timeout: time.Second * timeout}
response, err := client.Do(request)
if err != nil {
return nil, 0, err
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(response.Body)
status := response.StatusCode
resp, err := io.ReadAll(response.Body)
if err != nil {
return nil, status, err
}
return resp, status, nil
}
// GetJson application/json get 请求
func GetJson(url string, parameter []byte, timeout time.Duration) ([]byte, int, error) {
byteParameter := bytes.NewBuffer(parameter)
req, err := http.NewRequest("GET", url, byteParameter)
if err != nil {
return nil, 0, fmt.Errorf("[getjson] new request err: %v", err)
}
client := &http.Client{Timeout: timeout}
resp, err := client.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("[getjson] client do err: %v", err)
}
defer resp.Body.Close()
status := resp.StatusCode
res, err := io.ReadAll(resp.Body)
if err != nil {
return nil, status, fmt.Errorf("read response err: %v", err)
}
return res, status, nil
}
// GetJsonWithToken Get json with token 请求
func GetJsonWithToken(url string, token string, timeout time.Duration) ([]byte, int, error) {
request, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, 0, err
}
request.Header.Set("Content-type", "application/json")
request.Header.Set("Access-Token", token)
client := &http.Client{Timeout: timeout}
response, err := client.Do(request)
if err != nil {
return nil, 0, errors.New("请求网络错误")
}
defer response.Body.Close()
status := response.StatusCode
resp, err := io.ReadAll(response.Body)
if err != nil {
return nil, status, err
}
return resp, status, nil
}
// PostJson application/json post 请求
func PostJson(url string, parameter []byte, timeout time.Duration) ([]byte, int, error) {
byteParameter := bytes.NewBuffer(parameter)
request, err := http.NewRequest("POST", url, byteParameter)
if err != nil {
return nil, 0, fmt.Errorf("[postjson] new request err: %v", err)
}
request.Header.Set("Content-type", "application/json")
client := &http.Client{Timeout: timeout}
response, err := client.Do(request)
if err != nil {
return nil, 0, fmt.Errorf("[postjson] client do err: %v", err)
}
defer response.Body.Close()
status := response.StatusCode
all, err := io.ReadAll(response.Body)
if err != nil {
return nil, status, fmt.Errorf("read response err: %v", err)
}
return all, status, nil
}
// PostString application/x-www-form-urlencoded 请求
func PostString(url string, parameter []byte, timeout time.Duration) ([]byte, int, error) {
byteParameter := bytes.NewBuffer(parameter)
request, err := http.NewRequest("POST", url, byteParameter)
if err != nil {
return nil, 0, fmt.Errorf("[poststring] new request err: %v", err)
}
request.Header.Set("Content-type", "application/x-www-form-urlencoded")
client := &http.Client{Timeout: timeout}
response, err := client.Do(request)
if err != nil {
return nil, 0, fmt.Errorf("[poststring] client do err: %v", err)
}
defer response.Body.Close()
status := response.StatusCode
all, err := io.ReadAll(response.Body)
if err != nil {
return nil, status, fmt.Errorf("read response err: %v", err)
}
return all, status, nil
}
// PostJsonWithToken Post json with token 请求
func PostJsonWithToken(url string, parameter []byte, token string, timeout time.Duration) ([]byte, int, error) {
bufParameter := bytes.NewBuffer(parameter)
request, err := http.NewRequest("POST", url, bufParameter)
if err != nil {
return nil, 0, err
}
request.Header.Set("Content-type", "application/json")
request.Header.Set("Access-Token", token)
client := &http.Client{Timeout: timeout}
response, err := client.Do(request)
if err != nil {
return nil, 0, fmt.Errorf("请求网络错误: %v", err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(response.Body)
stauts := response.StatusCode
resp, err := io.ReadAll(response.Body)
if err != nil {
return nil, stauts, errors.New("网络请求结果读取失败")
}
return resp, stauts, nil
}
// PostJsonWithBearerToken Post json with bearer token 请求
func PostJsonWithBearerToken(url string, parameter []byte, token string, timeout time.Duration) ([]byte, int, error) {
bufParameter := bytes.NewBuffer(parameter)
request, err := http.NewRequest("POST", url, bufParameter)
if err != nil {
return nil, 0, err
}
request.Header.Set("Content-type", "application/json")
request.Header.Set("Authorization", fmt.Sprintf("bearer %s", token))
client := &http.Client{Timeout: timeout}
response, err := client.Do(request)
if err != nil {
return nil, 0, fmt.Errorf("请求网络错误: %v", err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(response.Body)
status := response.StatusCode
resp, err := io.ReadAll(response.Body)
if err != nil {
return nil, status, errors.New("网络请求结果读取失败")
}
return resp, status, nil
}
// func determiningEncoding(r *bufio.Reader) encoding.Encoding {
// b, err := r.Peek(1024)
// if err != nil && err != io.EOF {
// log.Printf("Fetcher error: %v", err)
// return unicode.UTF8
// }
// e, _, _ := charset.DetermineEncoding(b, "")
// return e
// }

View File

@@ -1,26 +0,0 @@
package gu
import (
"net/http"
"github.com/gin-gonic/gin"
)
func Cors() gin.HandlerFunc {
return func(c *gin.Context) {
method := c.Request.Method
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Headers", "Content-Type, AccessToken, X-CSRF-Token, Authorization, Token")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type")
c.Header("Access-Control-Allow-Credentials", "true")
// 放行所有OPTIONS方法
if method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
}
// 处理请求
c.Next()
}
}

View File

@@ -1,40 +0,0 @@
package gu
import (
"net/http"
"github.com/gin-gonic/gin"
)
type response struct {
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data"`
}
type PageData struct {
Total int64 `json:"total"`
PageID int32 `json:"page_id"`
PageSize int32 `json:"page_size"`
Result any `json:"result"`
}
func Ok(ctx *gin.Context, data any) {
ResponseJson(ctx, http.StatusOK, "ok", data)
}
func Failed(ctx *gin.Context, message string) {
ResponseJson(ctx, http.StatusInternalServerError, message, nil)
}
func FailedWithCode(ctx *gin.Context, code int, message string) {
ResponseJson(ctx, code, message, nil)
}
func ResponseJson(ctx *gin.Context, code int, message string, data any) {
ctx.JSON(code, response{
Code: code,
Message: message,
Data: data,
})
}

View File

@@ -1,57 +0,0 @@
package gu
import (
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/locales/en"
"github.com/go-playground/locales/zh"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
enTranslations "github.com/go-playground/validator/v10/translations/en"
chTranslations "github.com/go-playground/validator/v10/translations/zh"
)
var trans ut.Translator
// loca 通常取决于 http 请求头的 'Accept-Language'
func SetValidatorTrans(local string) (err error) {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
zhT := zh.New() // chinese
enT := en.New() // english
uni := ut.New(enT, zhT, enT)
var o bool
trans, o = uni.GetTranslator(local)
if !o {
return fmt.Errorf("uni.GetTranslator(%s) failed", local)
}
// register translate
// 注册翻译器
switch local {
case "en":
err = enTranslations.RegisterDefaultTranslations(v, trans)
case "zh":
err = chTranslations.RegisterDefaultTranslations(v, trans)
default:
err = enTranslations.RegisterDefaultTranslations(v, trans)
}
return
}
return
}
func ValidatorErrors(ctx *gin.Context, err error) {
if errors, ok := err.(validator.ValidationErrors); ok {
errs := gin.H{}
for _, e := range errors {
errs[e.StructField()] = strings.Replace(e.Translate(trans), e.StructField(), "", -1)
}
ctx.JSON(http.StatusBadRequest, errs)
} else {
ctx.JSON(http.StatusBadRequest, err)
}
}

View File

@@ -21,35 +21,35 @@ const (
)
var (
// pear admin 配置
PearAdmin = "m:pearjson"
// PearAdmin 配置
PearAdmin = "m:pear_json"
// 所有类别
// AllCategories 所有类别
AllCategories = "m:category:all"
// 所有类别 简单信息
AllCategorySimple = "m:categorysimple:all"
// 类别列表 根据 父id 获取
// AllCategorySimple 所有类别 简单信息
AllCategorySimple = "m:category_simple:all"
// ListCategoriesByParentID 类别列表 根据 父id 获取
ListCategoriesByParentID = "m:category:parent_id:%d"
// 所有部门
// AllDepartments 所有部门
AllDepartments = "m:department:all"
// 所有菜单
// AllMenus 所有菜单
AllMenus = "m:menus:all"
// 递归菜单
// RecursiveMenus 递归菜单
RecursiveMenus = "m:rec_menus:%d"
// 根据用户ID获取菜单
// AdminMenus 根据用户ID获取菜单
AdminMenus = "m:admin_menus:%d"
// 登陆用户的菜单
// OwnerMenus 登陆用户的菜单
OwnerMenus = "m:owner_menus:%d"
// 登陆用户的菜单
// OwnerMenusMap 登陆用户的菜单
OwnerMenusMap = "m:owner_menus_map:%d"
// 所有角色
// AllRoles 所有角色
AllRoles = "m:role:all"
)
func GetManageKey(ctx context.Context, key string, arg ...any) string {
func GetManageKey(_ context.Context, key string, arg ...any) string {
key = fmt.Sprintf(key, arg...)
return key
}

View File

@@ -1,54 +0,0 @@
package logger
import (
"os"
"time"
"management/internal/pkg/config"
"github.com/natefinch/lumberjack"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func New(prod bool) {
zerolog.SetGlobalLevel(zerolog.InfoLevel)
logRotate := &lumberjack.Logger{
Filename: "./log/run.log", // 日志文件的位置
MaxSize: 10, // 在进行切割之前,日志文件的最大大小(以MB为单位)
MaxBackups: 100, // 保留旧文件的最大个数
MaxAge: 30, // 保留旧文件的最大天数
Compress: true,
}
zerolog.TimeFieldFormat = time.DateTime
log.Logger = log.With().Caller().Logger()
if prod {
log.Logger = log.Output(logRotate)
} else {
consoleWriter := zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.DateTime}
multi := zerolog.MultiLevelWriter(consoleWriter, logRotate)
log.Logger = log.Output(multi)
}
}
func Init() {
zerolog.SetGlobalLevel(zerolog.InfoLevel)
logRotate := &lumberjack.Logger{
Filename: "./log/run.log", // 日志文件的位置
MaxSize: 10, // 在进行切割之前,日志文件的最大大小(以MB为单位)
MaxBackups: 100, // 保留旧文件的最大个数
MaxAge: 30, // 保留旧文件的最大天数
Compress: true,
}
zerolog.TimeFieldFormat = time.DateTime
log.Logger = log.With().Caller().Logger()
if config.File.App.Prod {
log.Logger = log.Output(logRotate)
} else {
consoleWriter := zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.DateTime}
multi := zerolog.MultiLevelWriter(consoleWriter, logRotate)
log.Logger = log.Output(multi)
}
}

View File

@@ -29,5 +29,5 @@ func (m *middleware) writeLog(req *http.Request, start time.Time) {
c, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
_ = m.auditLogsvc.Create(c, al)
_ = m.auditLogService.Create(c, al)
}

View File

@@ -28,21 +28,14 @@ func (m *middleware) Authorize(next http.Handler) http.Handler {
return
}
if user == nil {
http.Error(w, "user not found", http.StatusUnauthorized)
return
}
// 登陆成功 判断权限
// 默认权限判断
path := r.URL.Path
if b, ok := defaultMenus[path]; ok && b {
next.ServeHTTP(w, r)
return
}
menus, err := m.menusvc.ListByRoleIDToMap(ctx, user.RoleID)
menus, err := m.menuService.ListByRoleIDToMap(ctx, user.RoleID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@@ -63,10 +56,9 @@ func (m *middleware) isLogin(ctx context.Context) (*dto.AuthorizeUser, bool) {
if exists := m.session.Exists(ctx, know.StoreName); exists {
b := m.session.GetBytes(ctx, know.StoreName)
var user dto.AuthorizeUser
if err := json.Unmarshal(b, &user); err != nil {
return nil, false
if err := json.Unmarshal(b, &user); err == nil && user.ID > 0 {
return &user, true
}
return &user, true
}
return nil, false
@@ -84,7 +76,7 @@ func (m *middleware) AuthUser(ctx context.Context) dto.AuthorizeUser {
func (m *middleware) IsAuth(ctx context.Context) bool {
var user dto.AuthorizeUser
b := m.session.GetBytes(ctx, know.StoreName)
if err := json.Unmarshal(b, &user); err == nil {
if err := json.Unmarshal(b, &user); err == nil && user.ID > 0 {
return true
}
return false

View File

@@ -21,17 +21,15 @@ type Middleware interface {
}
type middleware struct {
session session.Session
menusvc v1.MenuService
auditLogsvc v1.AuditLogService
session session.Session
menuService v1.MenuService
auditLogService v1.AuditLogService
}
var _ Middleware = (*middleware)(nil)
func New(session session.Session, menusvc v1.MenuService, auditLogsvc v1.AuditLogService) Middleware {
func New(session session.Session, menuService v1.MenuService, auditLogService v1.AuditLogService) Middleware {
return &middleware{
session: session,
menusvc: menusvc,
auditLogsvc: auditLogsvc,
session: session,
menuService: menuService,
auditLogService: auditLogService,
}
}

View File

@@ -10,12 +10,13 @@ import (
"management/internal/pkg/config"
"github.com/drhin/logger"
"github.com/redis/go-redis/v9"
)
var ErrRedisKeyNotFound = errors.New("redis key not found")
type RedisCache interface {
type Cache interface {
Encode(a any) ([]byte, error)
Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error
Del(ctx context.Context, keys ...string) error
@@ -27,25 +28,29 @@ type RedisCache interface {
}
type redisCache struct {
engine *redis.Client
client *redis.Client
}
var _ RedisCache = (*redisCache)(nil)
func New(conf config.Redis) (*redisCache, error) {
func New(conf *config.Config, log *logger.Logger) (Cache, func(), error) {
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", conf.Host, conf.Port),
Password: conf.Password,
DB: conf.DB,
Addr: fmt.Sprintf("%s:%d", conf.Redis.Host, conf.Redis.Port),
Password: conf.Redis.Password,
DB: conf.Redis.DB,
})
_, err := rdb.Ping(context.Background()).Result()
if err != nil {
return nil, err
return nil, nil, err
}
cleanup := func() {
if err := rdb.Close(); err != nil {
log.Error("redis close error", err)
}
}
return &redisCache{
engine: rdb,
}, nil
client: rdb,
}, cleanup, nil
}
func (r *redisCache) Encode(a any) ([]byte, error) {
@@ -59,18 +64,18 @@ func (r *redisCache) Encode(a any) ([]byte, error) {
// Set 设置值
func (r *redisCache) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
return r.engine.Set(ctx, key, value, expiration).Err()
return r.client.Set(ctx, key, value, expiration).Err()
}
// Del 删除键值
func (r *redisCache) Del(ctx context.Context, keys ...string) error {
return r.engine.Del(ctx, keys...).Err()
return r.client.Del(ctx, keys...).Err()
}
// Get 获取值
func (r *redisCache) Get(ctx context.Context, key string) (string, error) {
val, err := r.engine.Get(ctx, key).Result()
if err == redis.Nil {
val, err := r.client.Get(ctx, key).Result()
if errors.Is(err, redis.Nil) {
return "", ErrRedisKeyNotFound
} else if err != nil {
return "", fmt.Errorf("cannot get value with:[%s]: %v", key, err)
@@ -81,8 +86,8 @@ func (r *redisCache) Get(ctx context.Context, key string) (string, error) {
// GetBytes 获取值
func (r *redisCache) GetBytes(ctx context.Context, key string) ([]byte, error) {
val, err := r.engine.Get(ctx, key).Bytes()
if err == redis.Nil {
val, err := r.client.Get(ctx, key).Bytes()
if errors.Is(err, redis.Nil) {
return nil, ErrRedisKeyNotFound
} else if err != nil {
return nil, fmt.Errorf("cannot get value with:[%s]: %v", key, err)
@@ -92,15 +97,15 @@ func (r *redisCache) GetBytes(ctx context.Context, key string) ([]byte, error) {
}
func (r *redisCache) Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd {
return r.engine.Scan(ctx, cursor, match, count)
return r.client.Scan(ctx, cursor, match, count)
}
func (r *redisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
return r.engine.Keys(ctx, pattern).Result()
return r.client.Keys(ctx, pattern).Result()
}
func (r *redisCache) ListKeys(ctx context.Context, pattern string, pageID int, pageSize int) ([]string, int, error) {
all, err := r.engine.Keys(ctx, pattern).Result()
all, err := r.client.Keys(ctx, pattern).Result()
if err != nil {
return nil, 0, err
}
@@ -115,7 +120,7 @@ func (r *redisCache) ListKeys(ctx context.Context, pattern string, pageID int, p
for {
var scanResult []string
var err error
scanResult, cursor, err = r.engine.Scan(ctx, cursor, pattern, int64(pageSize)).Result()
scanResult, cursor, err = r.client.Scan(ctx, cursor, pattern, int64(pageSize)).Result()
if err != nil {
return nil, count, err
}
@@ -135,36 +140,3 @@ func (r *redisCache) ListKeys(ctx context.Context, pattern string, pageID int, p
}
return keys[startIndex:endIndex], count, nil
}
// ==========================
func Encode(a any) ([]byte, error) {
return nil, nil
}
func Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
return nil
}
func Del(ctx context.Context, keys ...string) error {
return nil
}
func Get(ctx context.Context, key string) (string, error) {
return "", nil
}
func GetBytes(ctx context.Context, key string) ([]byte, error) {
return nil, nil
}
func Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd {
return nil
}
func Keys(ctx context.Context, pattern string) ([]string, error) {
return nil, nil
}
func ListKeys(ctx context.Context, pattern string, pageID int, pageSize int) ([]string, int, error) {
return nil, 0, nil
}

View File

@@ -1,4 +1,4 @@
package tpl
package funcs
import (
"html/template"
@@ -6,7 +6,7 @@ import (
"time"
)
func (r *render) Methods() map[string]any {
func Methods() map[string]any {
res := make(map[string]any, 1)
res["dateFormat"] = func(dt time.Time) template.HTML {
@@ -45,10 +45,10 @@ func (r *render) Methods() map[string]any {
res["expandTags"] = func(s []string) template.HTML {
if len(s) == 0 {
return template.HTML("")
return ""
}
if len(s) == 1 && s[0] == "all" {
return template.HTML("")
return ""
}
return template.HTML(strings.Join(s, ","))
}

View File

@@ -1,4 +1,4 @@
package tpl
package gen
import (
"html/template"
@@ -8,17 +8,17 @@ import (
"management/internal/erpserver/model/dto"
)
func (r *render) btnFuncs() map[string]any {
func Button() map[string]any {
res := make(map[string]any, 3)
res["genBtn"] = func(btns []*dto.OwnerMenuDto, actionNames ...string) template.HTML {
if len(btns) == 0 {
return template.HTML("")
res["genBtn"] = func(buttons []*dto.OwnerMenuDto, actionNames ...string) template.HTML {
if len(buttons) == 0 {
return ""
}
var res string
for _, action := range actionNames {
for _, btn := range btns {
for _, btn := range buttons {
btn.Style = strings.ReplaceAll(btn.Style, "pear", "layui")
base := filepath.Base(btn.Url)
if base == action {
@@ -34,14 +34,14 @@ func (r *render) btnFuncs() map[string]any {
return template.HTML(res)
}
res["genLink"] = func(btns []*dto.OwnerMenuDto, actionNames ...string) template.HTML {
if len(btns) == 0 {
return template.HTML("")
res["genLink"] = func(buttons []*dto.OwnerMenuDto, actionNames ...string) template.HTML {
if len(buttons) == 0 {
return ""
}
var res string
for _, action := range actionNames {
for _, btn := range btns {
for _, btn := range buttons {
btn.Style = strings.ReplaceAll(btn.Style, "pear", "layui")
base := filepath.Base(btn.Url)
if base == action {
@@ -64,14 +64,14 @@ func (r *render) btnFuncs() map[string]any {
return template.HTML(res)
}
res["submitBtn"] = func(btns []*dto.OwnerMenuDto, actionNames ...string) template.HTML {
if len(btns) == 0 {
return template.HTML("")
res["submitBtn"] = func(buttons []*dto.OwnerMenuDto, actionNames ...string) template.HTML {
if len(buttons) == 0 {
return ""
}
var res string
for _, action := range actionNames {
for _, btn := range btns {
for _, btn := range buttons {
btn.Style = strings.ReplaceAll(btn.Style, "pear", "layui")
base := filepath.Base(btn.Url)
if base == action {
@@ -89,3 +89,11 @@ func (r *render) btnFuncs() map[string]any {
return res
}
func firstLower(s string) string {
if len(s) == 0 {
return s
}
return strings.ToLower(s[:1]) + s[1:]
}

106
internal/pkg/render/html.go Normal file
View File

@@ -0,0 +1,106 @@
package render
import (
"bytes"
"context"
"encoding/json"
"fmt"
"html/template"
"net/http"
"path/filepath"
"strings"
"management/internal/erpserver/model/dto"
"management/internal/pkg/know"
"github.com/justinas/nosurf"
)
type TemplateConfig struct {
Root string
Extension string
Layout string
Partial string
}
type HtmlData struct {
IsAuthenticated bool
AuthorizeUser dto.AuthorizeUser
AuthorizeMenus []*dto.OwnerMenuDto
Data any
}
func (r *render) HTML(w http.ResponseWriter, req *http.Request, tpl string, data map[string]any) {
name := strings.ReplaceAll(tpl, "/", "_")
t, ok := r.templates[name]
if !ok {
http.Error(w, "template is empty", http.StatusInternalServerError)
return
}
hd := r.setDefaultData(req, data)
buf := new(bytes.Buffer)
err := t.ExecuteTemplate(buf, filepath.Base(tpl), hd)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, err = buf.WriteTo(w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
func (r *render) setDefaultData(req *http.Request, data map[string]any) map[string]any {
if data == nil {
data = make(map[string]any)
}
ctx := req.Context()
isAuth := r.session.Exists(ctx, know.StoreName)
data["IsAuthenticated"] = isAuth
if isAuth {
var authUser dto.AuthorizeUser
u := r.session.GetBytes(ctx, know.StoreName)
_ = json.Unmarshal(u, &authUser)
data["AuthorizeMenus"] = r.getCurrentPathButtons(ctx, authUser.RoleID, req.URL.Path)
}
token := nosurf.Token(req)
data["CsrfToken"] = token
data["CsrfTokenField"] = template.HTML(fmt.Sprintf(`<input type="hidden" name="csrf_token" value="%s" />`, token))
return data
}
func (r *render) getCurrentPathButtons(ctx context.Context, roleID int32, path string) []*dto.OwnerMenuDto {
var res []*dto.OwnerMenuDto
// 获取当前登陆角色的权限
menus, err := r.menuService.ListByRoleIDToMap(ctx, roleID)
if err != nil {
return res
}
menu, ok := menus[path]
if !ok {
return res
}
for _, item := range menus {
if menu.IsList {
if item.ParentID == menu.ID || item.ID == menu.ID {
res = append(res, item)
}
} else {
if item.ParentID == menu.ParentID {
res = append(res, item)
}
}
}
return res
}

View File

@@ -1,4 +1,4 @@
package tpl
package render
import (
"encoding/json"
@@ -11,12 +11,12 @@ type Response struct {
Data any `json:"data"`
}
type ResponseDtree struct {
Status ResponseDtreeStatus `json:"status"`
Data any `json:"data"`
type ResponseTree struct {
Status ResponseTreeStatus `json:"status"`
Data any `json:"data"`
}
type ResponseDtreeStatus struct {
type ResponseTreeStatus struct {
Code int `json:"code"`
Message string `json:"message"`
}
@@ -28,18 +28,18 @@ type ResponseList struct {
Data any `json:"data"`
}
func (r *render) JSONF(w http.ResponseWriter, success bool, message string) {
r.JSON(w, Response{Success: success, Message: message})
}
func (r *render) JSONOK(w http.ResponseWriter, message string) {
func (r *render) JSONOk(w http.ResponseWriter, message string) {
r.JSON(w, Response{Success: true, Message: message})
}
func (r *render) JSONERR(w http.ResponseWriter, message string) {
func (r *render) JSONErr(w http.ResponseWriter, message string) {
r.JSON(w, Response{Success: false, Message: message})
}
func (r *render) JSONObj(w http.ResponseWriter, message string, data any) {
r.JSON(w, Response{Success: true, Message: message, Data: data})
}
func (r *render) JSON(w http.ResponseWriter, data any) {
v, err := json.Marshal(data)
if err != nil {

View File

@@ -0,0 +1,60 @@
package render
import (
"html/template"
"net/http"
v1 "management/internal/erpserver/service/v1"
"management/internal/pkg/render/util"
"management/internal/pkg/session"
)
type Render interface {
htmlRender
jsonRender
}
type htmlRender interface {
HTML(w http.ResponseWriter, req *http.Request, name string, data map[string]any)
}
type jsonRender interface {
JSON(w http.ResponseWriter, data any)
JSONObj(w http.ResponseWriter, message string, data any)
JSONOk(w http.ResponseWriter, message string)
JSONErr(w http.ResponseWriter, message string)
}
type render struct {
templateConfig *TemplateConfig
templates map[string]*template.Template
session session.Session
menuService v1.MenuService
}
func New(session session.Session, menuService v1.MenuService) (Render, error) {
r := &render{
templateConfig: &TemplateConfig{
Root: ".",
Extension: ".tmpl",
Layout: "base",
Partial: "partial",
},
session: session,
menuService: menuService,
}
var err error
r.templates, err = util.CreateTemplateCache(
r.templateConfig.Root,
r.templateConfig.Layout,
r.templateConfig.Partial,
r.templateConfig.Extension,
)
if err != nil {
return nil, err
}
return r, nil
}

View File

@@ -0,0 +1,99 @@
package util
import (
"fmt"
"html/template"
"io/fs"
"os"
"path/filepath"
"slices"
"strings"
"management/internal/pkg/render/funcs"
"management/internal/pkg/render/gen"
templates "management/web/templates/manage"
)
func CreateTemplateCache(root, layout, partial, extension string) (map[string]*template.Template, error) {
cache := make(map[string]*template.Template)
pages, err := getFiles(root, extension)
if err != nil {
return nil, err
}
layoutAndPartial, err := getLayoutAndPartials(layout, partial, extension)
if err != nil {
return nil, err
}
for _, page := range pages {
if strings.HasPrefix(page, layout) || strings.HasSuffix(page, partial) {
continue
}
name := filepath.Base(page)
pathArr := strings.Split(page, "/")
dir := pathArr[len(pathArr)-2 : len(pathArr)-1]
templateName := fmt.Sprintf("%s_%s", dir[0], name)
ts := template.Must(template.New(templateName).Funcs(gen.Button()).Funcs(funcs.Methods()).ParseFS(templates.TemplateFS, page))
ts, err = ts.ParseFS(templates.TemplateFS, layoutAndPartial...)
if err != nil {
return nil, err
}
cache[templateName] = ts
}
return cache, nil
}
func getLayoutAndPartials(layout, partial, extension string) ([]string, error) {
layouts, err := getFiles(layout, extension)
if err != nil {
return nil, err
}
partials, err := getFiles(partial, extension)
if err != nil {
return nil, err
}
return slices.Concat(layouts, partials), nil
}
func getFiles(path string, suffix string) ([]string, error) {
files := make([]string, 0)
b, err := pathExists(templates.TemplateFS, path)
if err != nil {
return nil, err
}
if !b {
return files, nil
}
err = fs.WalkDir(templates.TemplateFS, path, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
if strings.HasSuffix(path, suffix) {
files = append(files, path)
}
return nil
})
return files, err
}
func pathExists(fs fs.FS, path string) (bool, error) {
_, err := fs.Open(path)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
return true, err
}

View File

@@ -2,12 +2,14 @@ package session
import (
"context"
"database/sql"
"net/http"
"time"
"management/internal/pkg/config"
"github.com/alexedwards/scs/postgresstore"
"github.com/alexedwards/scs/v2"
"gorm.io/gorm"
)
type Session interface {
@@ -23,7 +25,7 @@ type session struct {
sessionManager *scs.SessionManager
}
func New(db *sql.DB, prod bool) Session {
func New(db *gorm.DB, config *config.Config) Session {
sessionManager := scs.New()
sessionManager.Lifetime = 24 * time.Hour
sessionManager.IdleTimeout = 2 * time.Hour
@@ -31,11 +33,12 @@ func New(db *sql.DB, prod bool) Session {
sessionManager.Cookie.HttpOnly = true
sessionManager.Cookie.Persist = true
sessionManager.Cookie.SameSite = http.SameSiteStrictMode
sessionManager.Cookie.Secure = prod
sessionManager.Cookie.Secure = config.App.Prod
sqlDB, _ := db.DB()
// postgres
// github.com/alexedwards/scs/postgresstore
sessionManager.Store = postgresstore.New(db)
sessionManager.Store = postgresstore.New(sqlDB)
// pgx
// github.com/alexedwards/scs/pgxstore
// sessionManager.Store = pgxstore.New(pool)

View File

@@ -2,7 +2,7 @@ package sliceutil
import "sort"
// 使用 map 去除重复元素
// RemoveDuplicatesWithMap 使用 map 去除重复元素
func RemoveDuplicatesWithMap[T comparable](slice []T) []T {
result := make([]T, 0, len(slice))
seen := make(map[T]bool)
@@ -15,7 +15,7 @@ func RemoveDuplicatesWithMap[T comparable](slice []T) []T {
return result
}
// 先排序再去重
// RemoveDuplicatesWithSort 先排序再去重
func RemoveDuplicatesWithSort(slice []int) []int {
if len(slice) == 0 {
return slice

View File

@@ -1,41 +0,0 @@
package smb
import (
"io/fs"
"net"
"management/internal/pkg/config"
"github.com/hirochachacha/go-smb2"
)
var FS fs.FS
func Init() error {
conn, err := net.Dial("tcp", config.File.Smb.Host+":445")
if err != nil {
return err
}
defer conn.Close()
d := &smb2.Dialer{
Initiator: &smb2.NTLMInitiator{
User: config.File.Smb.Name,
Password: config.File.Smb.Pass,
},
}
s, err := d.Dial(conn)
if err != nil {
return err
}
defer s.Logoff()
fs, err := s.Mount(config.File.Smb.Mount)
if err != nil {
return err
}
defer fs.Umount()
FS = fs.DirFS(".")
return nil
}

View File

@@ -1,20 +0,0 @@
package snowflake
import (
"github.com/bwmarrin/snowflake"
)
var node *snowflake.Node
func Init() error {
n, err := snowflake.NewNode(1)
if err != nil {
return err
}
node = n
return nil
}
func GetId() int64 {
return node.Generate().Int64()
}

View File

@@ -1,25 +0,0 @@
package sqids
import "github.com/sqids/sqids-go"
var engine *sqids.Sqids
func Init() error {
var err error
engine, err = sqids.New(sqids.Options{
MinLength: 12,
Alphabet: "AvjM1lkB8N6cuhs2oFxnXyYDwCmLGI7JOzt9fW3HRgb5ZQrqaU04TePSVKdpiE",
})
if err != nil {
return err
}
return nil
}
func Encode(ids []uint64) (string, error) {
return engine.Encode(ids)
}
func Decode(s string) []uint64 {
return engine.Decode(s)
}

View File

@@ -1,59 +0,0 @@
package token
import (
"fmt"
"time"
tk "github.com/golang-jwt/jwt/v5"
)
const minSecretKeySize = 32
// JWTMaker JSON Web Token
type JWTMaker struct {
secretKey string
}
// NewJWTMaker 创建一个新的JWTMaker
func NewJWTMaker(secretKey string) error {
if len(secretKey) < minSecretKeySize {
return fmt.Errorf("invalid key size: must be at least %d characters", minSecretKeySize)
}
engine = &JWTMaker{secretKey}
return nil
}
// CreateToken 根据用户名和时间创建一个新的token
func (maker *JWTMaker) CreateToken(id string, username string, duration time.Duration) (string, *Payload, error) {
payload, err := NewPayload(id, username, duration)
if err != nil {
return "", payload, err
}
jwtToken := tk.NewWithClaims(tk.SigningMethodHS256, payload)
token, err := jwtToken.SignedString([]byte(maker.secretKey))
return token, payload, err
}
// VerifyToken checks if the token is valid or not
func (maker *JWTMaker) VerifyToken(token string) (*Payload, error) {
keyFunc := func(token *tk.Token) (interface{}, error) {
_, ok := token.Method.(*tk.SigningMethodHMAC)
if !ok {
return nil, ErrInvalidToken
}
return []byte(maker.secretKey), nil
}
jwtToken, err := tk.ParseWithClaims(token, &Payload{}, keyFunc)
if err != nil {
return nil, ErrInvalidToken
}
payload, ok := jwtToken.Claims.(*Payload)
if !ok {
return nil, ErrInvalidToken
}
return payload, nil
}

View File

@@ -1,24 +0,0 @@
package token
import (
"time"
)
var engine Maker
// Maker 管理token的接口定义
type Maker interface {
// CreateToken 根据用户名和时间创建一个新的token
CreateToken(id string, username string, duration time.Duration) (string, *Payload, error)
// VerifyToken 校验token是否正确
VerifyToken(token string) (*Payload, error)
}
func CreateToken(id string, username string, duration time.Duration) (string, *Payload, error) {
return engine.CreateToken(id, username, duration)
}
func VerifyToken(token string) (*Payload, error) {
return engine.VerifyToken(token)
}

View File

@@ -1,60 +0,0 @@
package token
import (
"fmt"
"time"
"management/internal/pkg/config"
"github.com/aead/chacha20poly1305"
"github.com/o1egl/paseto"
)
// PasetoMaker is a PASETO token maker
type PasetoMaker struct {
paseto *paseto.V2
symmetricKey []byte
}
// NewPasetoMaker creates a new PasetoMaker
func NewPasetoMaker() error {
symmetricKey := config.File.JWT.SigningKey
if len(symmetricKey) != chacha20poly1305.KeySize {
return fmt.Errorf("invalid key size: must be exactly %d characters", chacha20poly1305.KeySize)
}
engine = &PasetoMaker{
paseto: paseto.NewV2(),
symmetricKey: []byte(symmetricKey),
}
return nil
}
// CreateToken creates a new token for a specific username and duration
func (maker *PasetoMaker) CreateToken(id string, username string, duration time.Duration) (string, *Payload, error) {
payload, err := NewPayload(id, username, duration)
if err != nil {
return "", payload, err
}
token, err := maker.paseto.Encrypt(maker.symmetricKey, payload, nil)
return token, payload, err
}
// VerifyToken checks if the token is valid or not
func (maker *PasetoMaker) VerifyToken(token string) (*Payload, error) {
payload := &Payload{}
err := maker.paseto.Decrypt(token, maker.symmetricKey, payload, nil)
if err != nil {
return nil, ErrInvalidToken
}
err = payload.Valid()
if err != nil {
return nil, err
}
return payload, nil
}

View File

@@ -1,43 +0,0 @@
package token
import (
"errors"
"time"
tk "github.com/golang-jwt/jwt/v5"
)
// Different types of error returned by the VerifyToken function
var (
ErrInvalidToken = errors.New("token is invalid")
ErrExpiredToken = errors.New("token has expired")
)
// Payload contains the payload data of the token
type Payload struct {
ID string `json:"id"`
Username string `json:"username"`
tk.RegisteredClaims // v5版本新加的方法
}
// NewPayload creates a new token payload with a specific username and duration
func NewPayload(id string, username string, duration time.Duration) (*Payload, error) {
payload := &Payload{
id,
username,
tk.RegisteredClaims{
ExpiresAt: tk.NewNumericDate(time.Now().Add(duration)), // 过期时间24小时
IssuedAt: tk.NewNumericDate(time.Now()), // 签发时间
NotBefore: tk.NewNumericDate(time.Now()), // 生效时间
},
}
return payload, nil
}
// Valid checks if the token payload is valid or not
func (payload *Payload) Valid() error {
if time.Now().After(payload.ExpiresAt.Time) {
return ErrExpiredToken
}
return nil
}

View File

@@ -1,48 +0,0 @@
package tpl
import (
"bytes"
"net/http"
"path/filepath"
"strings"
"management/internal/erpserver/model/dto"
)
type TemplateConfig struct {
Root string
Extension string
Layout string
Partial string
}
type HtmlData struct {
IsAuthenticated bool
AuthorizeUser dto.AuthorizeUser
AuthorizeMenus []*dto.OwnerMenuDto
Data any
}
func (r *render) HTML(w http.ResponseWriter, req *http.Request, tpl string, data map[string]any) {
name := strings.ReplaceAll(tpl, "/", "_")
t, ok := r.templates[name]
if !ok {
http.Error(w, "template is empty", http.StatusInternalServerError)
return
}
hd := r.setDefaultData(req, data)
buf := new(bytes.Buffer)
err := t.ExecuteTemplate(buf, filepath.Base(tpl), hd)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, err = buf.WriteTo(w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}

View File

@@ -1,46 +0,0 @@
package tpl
import (
"html/template"
"net/http"
v1 "management/internal/erpserver/service/v1"
"management/internal/pkg/session"
)
type Renderer interface {
HTML(w http.ResponseWriter, req *http.Request, name string, data map[string]any)
JSON(w http.ResponseWriter, data any)
JSONF(w http.ResponseWriter, success bool, message string)
JSONOK(w http.ResponseWriter, message string)
JSONERR(w http.ResponseWriter, message string)
}
type render struct {
session session.Session
config *TemplateConfig
templates map[string]*template.Template
menusvc v1.MenuService
}
func New(session session.Session, menusvc v1.MenuService) (Renderer, error) {
render := &render{
session: session,
menusvc: menusvc,
config: &TemplateConfig{
Root: ".",
Extension: ".tmpl",
Layout: "base",
Partial: "partial",
},
}
templates, err := render.createTemplateCache()
if err != nil {
return nil, err
}
render.templates = templates
return render, nil
}

View File

@@ -1,180 +0,0 @@
package tpl
import (
"context"
"encoding/json"
"fmt"
"html/template"
"io/fs"
"net/http"
"os"
"path/filepath"
"slices"
"strings"
"management/internal/erpserver/model/dto"
"management/internal/pkg/know"
templates "management/web/templates/manage"
"github.com/justinas/nosurf"
)
func (r *render) setDefaultData(req *http.Request, data map[string]any) map[string]any {
if data == nil {
data = make(map[string]any)
}
ctx := req.Context()
isAuth := r.session.Exists(ctx, know.StoreName)
data["IsAuthenticated"] = isAuth
if isAuth {
var authUser dto.AuthorizeUser
u := r.session.GetBytes(ctx, know.StoreName)
_ = json.Unmarshal(u, &authUser)
data["AuthorizeMenus"] = r.getCurrentPathBtns(ctx, authUser.RoleID, req.URL.Path)
}
token := nosurf.Token(req)
data["CsrfToken"] = token
data["CsrfTokenField"] = template.HTML(fmt.Sprintf(`<input type="hidden" name="csrf_token" value="%s" />`, token))
return data
}
func (r *render) getCurrentPathBtns(ctx context.Context, roleID int32, path string) []*dto.OwnerMenuDto {
var res []*dto.OwnerMenuDto
// 获取当前登陆角色的权限
menus, err := r.menusvc.ListByRoleIDToMap(ctx, roleID)
if err != nil {
return res
}
menu, ok := menus[path]
if !ok {
return res
}
for _, item := range menus {
if menu.IsList {
if item.ParentID == menu.ID || item.ID == menu.ID {
res = append(res, item)
}
} else {
if item.ParentID == menu.ParentID {
res = append(res, item)
}
}
}
return res
}
func (r *render) createTemplateCache() (map[string]*template.Template, error) {
cache := make(map[string]*template.Template)
pages, err := getFiles(r.config.Root, r.config.Extension)
if err != nil {
return nil, err
}
layoutAndPartial, err := r.getLayoutAndPartials()
if err != nil {
return nil, err
}
for _, page := range pages {
if strings.HasPrefix(page, "base") || strings.HasSuffix(page, "partial") {
continue
}
name := filepath.Base(page)
pathArr := strings.Split(page, "/")
dir := pathArr[len(pathArr)-2 : len(pathArr)-1]
templateName := fmt.Sprintf("%s_%s", dir[0], name)
ts := template.Must(template.New(templateName).Funcs(r.btnFuncs()).Funcs(r.Methods()).ParseFS(templates.TemplateFS, page))
if err != nil {
return nil, err
}
ts, err = ts.ParseFS(templates.TemplateFS, layoutAndPartial...)
if err != nil {
return nil, err
}
cache[templateName] = ts
}
return cache, nil
}
func (r *render) getLayoutAndPartials() ([]string, error) {
layouts, err := getFiles(r.config.Layout, r.config.Extension)
if err != nil {
return nil, err
}
partials, err := getFiles(r.config.Partial, r.config.Extension)
if err != nil {
return nil, err
}
return slices.Concat(layouts, partials), nil
}
func getFiles(path string, stuffix string) ([]string, error) {
files := make([]string, 0)
b, err := pathExists(templates.TemplateFS, path)
if err != nil {
return nil, err
}
if !b {
return files, nil
}
err = fs.WalkDir(templates.TemplateFS, path, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
if strings.HasSuffix(path, stuffix) {
files = append(files, path)
}
return nil
})
// err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error {
// if info == nil {
// return err
// }
// if info.IsDir() {
// return nil
// }
// // 将模板后缀的文件放到列表
// if strings.HasSuffix(path, stuffix) {
// files = append(files, path)
// }
// return nil
// })
return files, err
}
func pathExists(fs fs.FS, path string) (bool, error) {
_, err := fs.Open(path)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
return true, err
}
func firstLower(s string) string {
if len(s) == 0 {
return s
}
return strings.ToLower(s[:1]) + s[1:]
}

View File

@@ -13,15 +13,15 @@ import (
// 初始化一个验证器实例
var validate = validator.New()
// 自定义验证规则requiredint
// 自定义验证规则
func init() {
validate.RegisterValidation("telephone", func(fl validator.FieldLevel) bool {
_ = validate.RegisterValidation("telephone", func(fl validator.FieldLevel) bool {
if err := IsValidPhone(fl.Field().String()); err != nil {
return false
}
return true
})
validate.RegisterValidation("dateonly", func(fl validator.FieldLevel) bool {
_ = validate.RegisterValidation("dateonly", func(fl validator.FieldLevel) bool {
_, err := time.ParseInLocation("2006-01-02", fl.Field().String(), time.Local)
if err != nil {
return false
@@ -33,7 +33,8 @@ func init() {
func ValidateForm(s any) error {
// 验证结构体数据
if err := validate.Struct(s); err != nil {
if _, ok := err.(*validator.InvalidValidationError); ok {
var invalidValidationError *validator.InvalidValidationError
if errors.As(err, &invalidValidationError) {
return errors.New("验证器配置错误")
}