From 38ee553cf3cdd6860a98df441fea2f5a859cd8be Mon Sep 17 00:00:00 2001 From: kenneth <1185230223@qq.com> Date: Wed, 7 May 2025 15:32:05 +0800 Subject: [PATCH] gorm update --- cmd/erp.go | 6 +- internal/erpserver/handler/handler.go | 35 ++-- internal/erpserver/handler/system/user.go | 14 +- internal/erpserver/http.go | 170 +++++++++--------- internal/erpserver/model/system/audit_log.go | 2 +- internal/erpserver/service/service.go | 4 +- internal/erpserver/service/v1/system/user.go | 13 +- internal/erpserver/wire.go | 4 +- internal/erpserver/wire_gen.go | 16 +- internal/pkg/file/upload.go | 2 +- internal/pkg/middleware/audit.go | 63 ++++--- internal/pkg/middleware/authorize.go | 119 +++++------- .../pkg/middleware/{nocsrf.go => csrf.go} | 2 +- internal/pkg/middleware/middleware.go | 35 ---- internal/pkg/middleware/session.go | 6 +- internal/pkg/render/html.go | 13 +- internal/pkg/render/render.go | 4 +- internal/pkg/session/session.go | 76 ++++---- 18 files changed, 283 insertions(+), 301 deletions(-) rename internal/pkg/middleware/{nocsrf.go => csrf.go} (62%) delete mode 100644 internal/pkg/middleware/middleware.go diff --git a/cmd/erp.go b/cmd/erp.go index 7743953..b389408 100644 --- a/cmd/erp.go +++ b/cmd/erp.go @@ -40,7 +40,7 @@ var erpCmd = &cobra.Command{ } func runErp() error { - config, err := config.New(configPath) + conf, err := config.New(configPath) if err != nil { return err } @@ -50,14 +50,14 @@ func runErp() error { return err } - mux, fn, err := erpserver.NewWire(config, l) + mux, fn, err := erpserver.NewWire(conf, l) if err != nil { return err } defer fn() - address := fmt.Sprintf("%s:%d", config.App.Host, config.App.Port) + address := fmt.Sprintf("%s:%d", conf.App.Host, conf.App.Port) log.Printf("Starting manage server on %s", address) if runtime.GOOS == "windows" { diff --git a/internal/erpserver/handler/handler.go b/internal/erpserver/handler/handler.go index f94b5ff..db09993 100644 --- a/internal/erpserver/handler/handler.go +++ b/internal/erpserver/handler/handler.go @@ -6,31 +6,32 @@ import ( "management/internal/erpserver/model/dto" "management/internal/pkg/config" - "management/internal/pkg/middleware" + "management/internal/pkg/know" "management/internal/pkg/render" + "management/internal/pkg/session" "github.com/drhin/logger" ) type Handler struct { - Config *config.Config - Log *logger.Logger - Middleware middleware.Middleware + Config *config.Config + Log *logger.Logger - render render.Render + session session.Manager + render render.Render } func NewHandler( config *config.Config, log *logger.Logger, - middleware middleware.Middleware, + session session.Manager, render render.Render, ) *Handler { return &Handler{ - Config: config, - Log: log, - Middleware: middleware, - render: render, + Config: config, + Log: log, + session: session, + render: render, } } @@ -38,7 +39,19 @@ func NewHandler( // middleware 帮助方法 func (h *Handler) AuthUser(ctx context.Context) dto.AuthorizeUser { - return h.Middleware.AuthUser(ctx) + u, err := h.session.GetUser(ctx, know.StoreName) + if err != nil { + return dto.AuthorizeUser{} + } + return *u +} + +func (h *Handler) RenewToken(ctx context.Context) error { + return h.session.RenewToken(ctx) +} + +func (h *Handler) Destroy(ctx context.Context) error { + return h.session.Destroy(ctx) } // ===================================================================================================================== diff --git a/internal/erpserver/handler/system/user.go b/internal/erpserver/handler/system/user.go index 921da2e..1d57ac9 100644 --- a/internal/erpserver/handler/system/user.go +++ b/internal/erpserver/handler/system/user.go @@ -167,11 +167,15 @@ func (h *UserHandler) Login(w http.ResponseWriter, r *http.Request) { ctx := r.Context() switch r.Method { case http.MethodGet: - if h.Middleware.IsAuth(ctx) && h.Middleware.RefreshToken(ctx) { - http.Redirect(w, r, "/home.html", http.StatusFound) - return + user := h.AuthUser(ctx) + if user.ID > 0 { + if err := h.RenewToken(ctx); err == nil { + http.Redirect(w, r, "/home.html", http.StatusFound) + return + } } - _ = h.Middleware.Destroy(ctx) + + _ = h.Destroy(ctx) h.HTML(w, r, "oauth/login.tmpl", nil) case http.MethodPost: defer func(Body io.ReadCloser) { @@ -207,7 +211,7 @@ func (h *UserHandler) Login(w http.ResponseWriter, r *http.Request) { } func (h *UserHandler) Logout(w http.ResponseWriter, r *http.Request) { - err := h.Middleware.Destroy(r.Context()) + err := h.Destroy(r.Context()) if err != nil { h.Log.Error(err.Error(), err) } diff --git a/internal/erpserver/http.go b/internal/erpserver/http.go index 0256b3e..19b92b0 100644 --- a/internal/erpserver/http.go +++ b/internal/erpserver/http.go @@ -5,14 +5,20 @@ import ( "management/internal/erpserver/handler/common" "management/internal/erpserver/handler/system" + v1 "management/internal/erpserver/service/v1" mi "management/internal/pkg/middleware" + "management/internal/pkg/session" + "github.com/drhin/logger" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" ) func NewHTTPServer( - mi mi.Middleware, + sm session.Manager, + log *logger.Logger, + menuService v1.MenuService, + auditLogService v1.AuditLogService, captchaHandler *common.CaptchaHandler, uploadHandler *common.UploadHandler, configHandler *system.ConfigHandler, @@ -34,106 +40,108 @@ func NewHTTPServer( staticServer := http.FileServer(http.Dir("./web/statics/")) r.Handle("/statics/*", http.StripPrefix("/statics", staticServer)) - uploadServer := http.FileServer(http.Dir("./upload/")) - r.Handle("/upload/*", http.StripPrefix("/upload", uploadServer)) + uploadServer := http.FileServer(http.Dir("./public/")) + r.Handle("/public/*", http.StripPrefix("/public", uploadServer)) r.Group(func(r chi.Router) { - r.Use(mi.NoSurf) // CSRF - r.Use(mi.LoadSession) // Session - - r.Get("/captcha", captchaHandler.Captcha) - - r.With(mi.Authorize, mi.Audit).Post("/upload/img", uploadHandler.Img) - r.With(mi.Authorize, mi.Audit).Post("/upload/file", uploadHandler.File) - r.With(mi.Authorize, mi.Audit).Post("/upload/multi_files", uploadHandler.MultiFiles) + r.Use(mi.NoSurf) // CSRF + r.Use(mi.LoadSession(sm)) // Session r.Get("/", userHandler.Login) + r.Get("/captcha", captchaHandler.Captcha) r.Post("/login", userHandler.Login) - r.Get("/logout", userHandler.Logout) - r.With(mi.Authorize).Get("/home.html", homeHandler.Home) - r.With(mi.Authorize).Get("/dashboard", homeHandler.Dashboard) + r.Group(func(r chi.Router) { + r.Use(mi.Authorize(sm, menuService)) - r.With(mi.Authorize).Get("/pear.json", configHandler.Pear) + r.Get("/logout", userHandler.Logout) - r.Route("/system", func(r chi.Router) { - r.Use(mi.Authorize) + r.Get("/home.html", homeHandler.Home) + r.Get("/dashboard", homeHandler.Dashboard) + r.Get("/pear.json", configHandler.Pear) - r.Route("/config", func(r chi.Router) { - r.Use(mi.Audit) - r.Get("/list", configHandler.List) - r.Post("/list", configHandler.List) - r.Get("/add", configHandler.Add) - r.Get("/edit", configHandler.Edit) - r.Post("/save", configHandler.Save) - r.Post("/refresh_cache", configHandler.RefreshCache) - r.Post("/reset_pear", configHandler.ResetPear) + r.Route("/upload", func(r chi.Router) { + r.Use(mi.Audit(sm, auditLogService, log)) + r.Get("/img", uploadHandler.Img) + r.Get("/file", uploadHandler.File) + r.Get("/multi_files", uploadHandler.MultiFiles) }) - r.Get("/menus", menuHandler.Menus) - r.Route("/menu", func(r chi.Router) { - r.Use(mi.Audit) - r.Get("/list", menuHandler.List) - r.Post("/list", menuHandler.List) - r.Get("/add", menuHandler.Add) - r.Get("/add_children", menuHandler.AddChildren) - r.Get("/edit", menuHandler.Edit) - r.Post("/save", menuHandler.Save) - r.Post("/data", menuHandler.Data) - r.Post("/refresh_cache", menuHandler.RefreshCache) - }) + r.Route("/system", func(r chi.Router) { + r.Use(mi.Audit(sm, auditLogService, log)) - r.Route("/department", func(r chi.Router) { - r.Use(mi.Audit) + r.Get("/menus", menuHandler.Menus) - r.Get("/list", departmentHandler.List) - r.Post("/list", departmentHandler.List) - r.Get("/add", departmentHandler.Add) - r.Get("/add_children", departmentHandler.AddChildren) - r.Get("/edit", departmentHandler.Edit) - r.Post("/save", departmentHandler.Save) - r.Post("/data", departmentHandler.Data) - r.Post("/refresh_cache", departmentHandler.RefreshCache) - r.Post("/rebuild_parent_path", departmentHandler.RebuildParentPath) - }) + r.Route("/config", func(r chi.Router) { + r.Get("/list", configHandler.List) + r.Post("/list", configHandler.List) + r.Get("/add", configHandler.Add) + r.Get("/edit", configHandler.Edit) + r.Post("/save", configHandler.Save) + r.Post("/refresh_cache", configHandler.RefreshCache) + r.Post("/reset_pear", configHandler.ResetPear) + }) - r.Route("/role", func(r chi.Router) { - r.Use(mi.Audit) + r.Route("/menu", func(r chi.Router) { + r.Get("/list", menuHandler.List) + r.Post("/list", menuHandler.List) + r.Get("/add", menuHandler.Add) + r.Get("/add_children", menuHandler.AddChildren) + r.Get("/edit", menuHandler.Edit) + r.Post("/save", menuHandler.Save) + r.Post("/data", menuHandler.Data) + r.Post("/refresh_cache", menuHandler.RefreshCache) + }) - r.Get("/list", roleHandler.List) - r.Post("/list", roleHandler.List) - r.Get("/add", roleHandler.Add) - r.Get("/add_children", roleHandler.AddChildren) - r.Get("/edit", roleHandler.Edit) - r.Post("/save", roleHandler.Save) - r.Post("/data", roleHandler.Data) - r.Post("/refresh_cache", roleHandler.RefreshCache) - r.Post("/rebuild_parent_path", roleHandler.RebuildParentPath) - r.Post("/refresh_role_menus", roleHandler.RefreshRoleMenus) - r.Get("/set_menu", roleHandler.SetMenu) - r.Post("/set_menu", roleHandler.SetMenu) - }) + r.Route("/department", func(r chi.Router) { + r.Get("/list", departmentHandler.List) + r.Post("/list", departmentHandler.List) + r.Get("/add", departmentHandler.Add) + r.Get("/add_children", departmentHandler.AddChildren) + r.Get("/edit", departmentHandler.Edit) + r.Post("/save", departmentHandler.Save) + r.Post("/data", departmentHandler.Data) + r.Post("/refresh_cache", departmentHandler.RefreshCache) + r.Post("/rebuild_parent_path", departmentHandler.RebuildParentPath) + }) - r.Route("/user", func(r chi.Router) { - r.Get("/list", userHandler.List) - r.Post("/list", userHandler.List) - r.Get("/add", userHandler.Add) - r.Get("/edit", userHandler.Edit) - r.Post("/save", userHandler.Save) - r.Get("/profile", userHandler.Profile) - r.Post("/data", userHandler.Data) - }) + r.Route("/role", func(r chi.Router) { + r.Get("/list", roleHandler.List) + r.Post("/list", roleHandler.List) + r.Get("/add", roleHandler.Add) + r.Get("/add_children", roleHandler.AddChildren) + r.Get("/edit", roleHandler.Edit) + r.Post("/save", roleHandler.Save) + r.Post("/data", roleHandler.Data) + r.Post("/refresh_cache", roleHandler.RefreshCache) + r.Post("/rebuild_parent_path", roleHandler.RebuildParentPath) + r.Post("/refresh_role_menus", roleHandler.RefreshRoleMenus) + r.Get("/set_menu", roleHandler.SetMenu) + r.Post("/set_menu", roleHandler.SetMenu) + }) - r.Route("/login_log", func(r chi.Router) { - r.Get("/list", loginLogHandler.List) - r.Post("/list", loginLogHandler.List) - }) + r.Route("/user", func(r chi.Router) { + r.Get("/list", userHandler.List) + r.Post("/list", userHandler.List) + r.Get("/add", userHandler.Add) + r.Get("/edit", userHandler.Edit) + r.Post("/save", userHandler.Save) + r.Get("/profile", userHandler.Profile) + r.Post("/data", userHandler.Data) + }) - r.Route("/audit_log", func(r chi.Router) { - r.Get("/list", auditHandler.List) - r.Post("/list", auditHandler.List) + r.Route("/login_log", func(r chi.Router) { + r.Get("/list", loginLogHandler.List) + r.Post("/list", loginLogHandler.List) + }) + + r.Route("/audit_log", func(r chi.Router) { + r.Get("/list", auditHandler.List) + r.Post("/list", auditHandler.List) + }) }) }) + }) return r diff --git a/internal/erpserver/model/system/audit_log.go b/internal/erpserver/model/system/audit_log.go index db51663..d284f34 100644 --- a/internal/erpserver/model/system/audit_log.go +++ b/internal/erpserver/model/system/audit_log.go @@ -45,7 +45,7 @@ func NewAuditLog(r *http.Request, email, os, browser string, start, end time.Tim contentType := r.Header.Get("Content-Type") if strings.Contains(contentType, "application/json") { body := make([]byte, r.ContentLength) - r.Body.Read(body) + _, _ = r.Body.Read(body) params = string(body) } else if strings.Contains(contentType, "application/x-www-form-urlencoded") { params = r.Form.Encode() diff --git a/internal/erpserver/service/service.go b/internal/erpserver/service/service.go index d0f0f74..5a812fb 100644 --- a/internal/erpserver/service/service.go +++ b/internal/erpserver/service/service.go @@ -11,14 +11,14 @@ import ( type Service struct { Log *logger.Logger Tx repository.Transaction - Session session.Session + Session session.Manager Redis redis.Cache } func NewService( log *logger.Logger, tx repository.Transaction, - session session.Session, + session session.Manager, redis redis.Cache, ) *Service { return &Service{ diff --git a/internal/erpserver/service/v1/system/user.go b/internal/erpserver/service/v1/system/user.go index bd32081..8251e59 100644 --- a/internal/erpserver/service/v1/system/user.go +++ b/internal/erpserver/service/v1/system/user.go @@ -2,7 +2,6 @@ package system import ( "context" - "encoding/json" "errors" "strconv" "time" @@ -181,7 +180,7 @@ func (s *userService) login(ctx context.Context, req *form.Login) error { } func (s *userService) loginSuccess(ctx context.Context, user *system.User, req *form.Login) error { - auth := dto.AuthorizeUser{ + return s.Session.PutUser(ctx, know.StoreName, &dto.AuthorizeUser{ ID: user.ID, Uuid: user.Uuid, Email: user.Email, @@ -191,13 +190,5 @@ func (s *userService) loginSuccess(ctx context.Context, user *system.User, req * OS: req.Os, IP: req.Ip, Browser: req.Browser, - } - - gob, err := json.Marshal(auth) - if err != nil { - return err - } - - s.Session.Put(ctx, know.StoreName, gob) - return nil + }) } diff --git a/internal/erpserver/wire.go b/internal/erpserver/wire.go index 0025131..940e939 100644 --- a/internal/erpserver/wire.go +++ b/internal/erpserver/wire.go @@ -13,7 +13,6 @@ import ( commonService "management/internal/erpserver/service/v1/common" systemService "management/internal/erpserver/service/v1/system" "management/internal/pkg/config" - "management/internal/pkg/middleware" "management/internal/pkg/redis" "management/internal/pkg/render" "management/internal/pkg/session" @@ -72,9 +71,8 @@ func NewWire(*config.Config, *logger.Logger) (*chi.Mux, func(), error) { panic(wire.Build( repositorySet, redis.New, - session.New, + session.NewSCSManager, serviceSet, - middleware.New, render.New, handlerSet, serverSet, diff --git a/internal/erpserver/wire_gen.go b/internal/erpserver/wire_gen.go index 38a39b4..796623e 100644 --- a/internal/erpserver/wire_gen.go +++ b/internal/erpserver/wire_gen.go @@ -19,7 +19,6 @@ import ( "management/internal/erpserver/service/v1/common" system2 "management/internal/erpserver/service/v1/system" "management/internal/pkg/config" - "management/internal/pkg/middleware" "management/internal/pkg/redis" "management/internal/pkg/render" "management/internal/pkg/session" @@ -32,7 +31,11 @@ func NewWire(configConfig *config.Config, loggerLogger *logger.Logger) (*chi.Mux if err != nil { return nil, nil, err } - sessionSession := session.New(db, configConfig) + manager, err := session.NewSCSManager(db, configConfig) + if err != nil { + cleanup() + return nil, nil, err + } repositoryRepository := repository.NewRepository(db, loggerLogger) transaction := repository.NewTransaction(repositoryRepository) cache, cleanup2, err := redis.New(configConfig, loggerLogger) @@ -40,7 +43,7 @@ func NewWire(configConfig *config.Config, loggerLogger *logger.Logger) (*chi.Mux cleanup() return nil, nil, err } - serviceService := service.NewService(loggerLogger, transaction, sessionSession, cache) + serviceService := service.NewService(loggerLogger, transaction, manager, cache) menuRepository := system.NewMenuRepository(repositoryRepository) roleRepository := system.NewRoleRepository(repositoryRepository) roleService := system2.NewRoleService(serviceService, roleRepository) @@ -49,14 +52,13 @@ func NewWire(configConfig *config.Config, loggerLogger *logger.Logger) (*chi.Mux menuService := system2.NewMenuService(serviceService, menuRepository, roleService, roleMenuService) auditLogRepository := system.NewAuditLogRepository(repositoryRepository) auditLogService := system2.NewAuditLogService(serviceService, auditLogRepository) - middlewareMiddleware := middleware.New(sessionSession, menuService, auditLogService) - renderRender, err := render.New(sessionSession, menuService) + renderRender, err := render.New(manager, menuService) if err != nil { cleanup2() cleanup() return nil, nil, err } - handlerHandler := handler.NewHandler(configConfig, loggerLogger, middlewareMiddleware, renderRender) + handlerHandler := handler.NewHandler(configConfig, loggerLogger, manager, renderRender) captchaService := common.NewCaptchaService() captchaHandler := common2.NewCaptchaHandler(handlerHandler, captchaService) uploadHandler := common2.NewUploadHandler(handlerHandler) @@ -76,7 +78,7 @@ func NewWire(configConfig *config.Config, loggerLogger *logger.Logger) (*chi.Mux menuHandler := system3.NewMenuHandler(handlerHandler, menuService) roleHandler := system3.NewRoleHandler(handlerHandler, roleService, menuService) departmentHandler := system3.NewDepartmentHandler(handlerHandler, departmentService) - mux := NewHTTPServer(middlewareMiddleware, captchaHandler, uploadHandler, configHandler, homeHandler, userHandler, loginLogHandler, auditHandler, menuHandler, roleHandler, departmentHandler) + mux := NewHTTPServer(manager, loggerLogger, menuService, auditLogService, captchaHandler, uploadHandler, configHandler, homeHandler, userHandler, loginLogHandler, auditHandler, menuHandler, roleHandler, departmentHandler) return mux, func() { cleanup2() cleanup() diff --git a/internal/pkg/file/upload.go b/internal/pkg/file/upload.go index 9985793..808aa4a 100644 --- a/internal/pkg/file/upload.go +++ b/internal/pkg/file/upload.go @@ -150,7 +150,7 @@ func UploadFile(file *multipart.FileHeader, t FileType) (string, error) { } func GetPath() string { - return fmt.Sprintf("upload/%s/%s/%s/", time.Now().Format("2006"), time.Now().Format("01"), time.Now().Format("02")) + return fmt.Sprintf("public/%s/%s/%s/", time.Now().Format("2006"), time.Now().Format("01"), time.Now().Format("02")) } func GenFilename(ext string) string { diff --git a/internal/pkg/middleware/audit.go b/internal/pkg/middleware/audit.go index 9639137..e74d15c 100644 --- a/internal/pkg/middleware/audit.go +++ b/internal/pkg/middleware/audit.go @@ -1,33 +1,52 @@ package middleware import ( - "context" + "errors" "net/http" "time" systemmodel "management/internal/erpserver/model/system" + v1 "management/internal/erpserver/service/v1" + "management/internal/pkg/know" + "management/internal/pkg/session" + + "github.com/drhin/logger" + "go.uber.org/zap" ) -func (m *middleware) Audit(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - defer func(res http.ResponseWriter, req *http.Request) { - // 记录审计日志 - go m.writeLog(req, start) - }(w, r) - next.ServeHTTP(w, r) +func Audit(sess session.Manager, auditLogService v1.AuditLogService, log *logger.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + defer func() { + go func() { + ctx := r.Context() + user, err := sess.GetUser(ctx, know.StoreName) + if err != nil { + log.Error(err.Error(), err) + return + } + + if user.ID == 0 { + log.Error("scs get user is empty", errors.New("scs get user is empty")) + return + } + + al := systemmodel.NewAuditLog(r, user.Email, user.OS, user.Browser, start, time.Now()) + if err := auditLogService.Create(ctx, al); err != nil { + log.Error(err.Error(), err, + zap.Int32("user_id", user.ID), + zap.String("user", user.Email), + zap.String("ip", al.Ip), + zap.String("os", al.Os), + zap.String("method", al.Method), + zap.String("path", al.Url), + ) + } + }() + }() + + next.ServeHTTP(w, r) + }) } - - return http.HandlerFunc(fn) -} - -func (m *middleware) writeLog(req *http.Request, start time.Time) { - end := time.Now() - user := m.AuthUser(req.Context()) - al := systemmodel.NewAuditLog(req, user.Email, user.OS, user.Browser, start, end) - - c, cancel := context.WithTimeout(context.Background(), time.Second*3) - defer cancel() - - _ = m.auditLogService.Create(c, al) } diff --git a/internal/pkg/middleware/authorize.go b/internal/pkg/middleware/authorize.go index f7ff130..14e3aa4 100644 --- a/internal/pkg/middleware/authorize.go +++ b/internal/pkg/middleware/authorize.go @@ -1,94 +1,59 @@ package middleware import ( - "context" - "encoding/json" "net/http" "management/internal/erpserver/model/dto" + v1 "management/internal/erpserver/service/v1" "management/internal/pkg/know" + "management/internal/pkg/session" ) -var defaultMenus = map[string]bool{ - "/home.html": true, - "/dashboard": true, - "/system/menus": true, - "/upload/img": true, - "/upload/file": true, - "/upload/mutilfile": true, - "/pear.json": true, +var publicRoutes = map[string]bool{ + "/home.html": true, + "/dashboard": true, + "/system/menus": true, + "/upload/img": true, + "/upload/file": true, + "/upload/multi_files": true, + "/pear.json": true, } -func (m *middleware) Authorize(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - user, ok := m.isLogin(ctx) - if !ok { - http.Redirect(w, r, "/", http.StatusFound) - return - } +func Authorize( + sess session.Manager, + menuService v1.MenuService, +) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + path := r.URL.Path + + // 登陆检查 + user, err := sess.GetUser(ctx, know.StoreName) + if err != nil || user.ID == 0 { + http.Redirect(w, r, "/", http.StatusFound) + return + } + + // 公共路由放行 + if publicRoutes[path] { + next.ServeHTTP(w, r) + return + } + + // 权限检查 + menus, err := menuService.ListByRoleIDToMap(ctx, user.RoleID) + if err != nil || !hasPermission(menus, path) { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } - // 登陆成功 判断权限 - path := r.URL.Path - if b, ok := defaultMenus[path]; ok && b { next.ServeHTTP(w, r) - return - } - - menus, err := m.menuService.ListByRoleIDToMap(ctx, user.RoleID) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - if _, ok := menus[path]; ok { - next.ServeHTTP(w, r) - return - } - - http.Error(w, "Unauthorized", http.StatusUnauthorized) + }) } - - return http.HandlerFunc(fn) } -func (m *middleware) isLogin(ctx context.Context) (*dto.AuthorizeUser, bool) { - if exists := m.session.Exists(ctx, know.StoreName); exists { - b := m.session.GetBytes(ctx, know.StoreName) - var user dto.AuthorizeUser - if err := json.Unmarshal(b, &user); err == nil && user.ID > 0 { - return &user, true - } - } - - return nil, false -} - -func (m *middleware) AuthUser(ctx context.Context) dto.AuthorizeUser { - var user dto.AuthorizeUser - if exists := m.session.Exists(ctx, know.StoreName); exists { - b := m.session.GetBytes(ctx, know.StoreName) - _ = json.Unmarshal(b, &user) - } - return user -} - -func (m *middleware) IsAuth(ctx context.Context) bool { - var user dto.AuthorizeUser - b := m.session.GetBytes(ctx, know.StoreName) - if err := json.Unmarshal(b, &user); err == nil && user.ID > 0 { - return true - } - return false -} - -func (m *middleware) RefreshToken(ctx context.Context) bool { - if err := m.session.RenewToken(ctx); err == nil { - return true - } - return false -} - -func (m *middleware) Destroy(ctx context.Context) error { - return m.session.Destroy(ctx) +func hasPermission(menus map[string]*dto.OwnerMenuDto, path string) bool { + _, ok := menus[path] + return ok } diff --git a/internal/pkg/middleware/nocsrf.go b/internal/pkg/middleware/csrf.go similarity index 62% rename from internal/pkg/middleware/nocsrf.go rename to internal/pkg/middleware/csrf.go index 2b098a8..4722fc5 100644 --- a/internal/pkg/middleware/nocsrf.go +++ b/internal/pkg/middleware/csrf.go @@ -6,6 +6,6 @@ import ( "github.com/justinas/nosurf" ) -func (m *middleware) NoSurf(next http.Handler) http.Handler { +func NoSurf(next http.Handler) http.Handler { return nosurf.New(next) } diff --git a/internal/pkg/middleware/middleware.go b/internal/pkg/middleware/middleware.go deleted file mode 100644 index 01817d6..0000000 --- a/internal/pkg/middleware/middleware.go +++ /dev/null @@ -1,35 +0,0 @@ -package middleware - -import ( - "context" - "net/http" - - "management/internal/erpserver/model/dto" - v1 "management/internal/erpserver/service/v1" - "management/internal/pkg/session" -) - -type Middleware interface { - Audit(next http.Handler) http.Handler - NoSurf(next http.Handler) http.Handler - LoadSession(next http.Handler) http.Handler - Authorize(next http.Handler) http.Handler - AuthUser(ctx context.Context) dto.AuthorizeUser - IsAuth(ctx context.Context) bool - RefreshToken(ctx context.Context) bool - Destroy(ctx context.Context) error -} - -type middleware struct { - session session.Session - menuService v1.MenuService - auditLogService v1.AuditLogService -} - -func New(session session.Session, menuService v1.MenuService, auditLogService v1.AuditLogService) Middleware { - return &middleware{ - session: session, - menuService: menuService, - auditLogService: auditLogService, - } -} diff --git a/internal/pkg/middleware/session.go b/internal/pkg/middleware/session.go index 07cf596..382d2e9 100644 --- a/internal/pkg/middleware/session.go +++ b/internal/pkg/middleware/session.go @@ -2,8 +2,10 @@ package middleware import ( "net/http" + + "management/internal/pkg/session" ) -func (m *middleware) LoadSession(next http.Handler) http.Handler { - return m.session.LoadAndSave(next) +func LoadSession(sm session.Manager) func(http.Handler) http.Handler { + return sm.Load } diff --git a/internal/pkg/render/html.go b/internal/pkg/render/html.go index 6fd96c6..2da8a93 100644 --- a/internal/pkg/render/html.go +++ b/internal/pkg/render/html.go @@ -3,7 +3,6 @@ package render import ( "bytes" "context" - "encoding/json" "fmt" "html/template" "net/http" @@ -60,15 +59,15 @@ func (r *render) setDefaultData(req *http.Request, data map[string]any) map[stri } ctx := req.Context() - isAuth := r.session.Exists(ctx, know.StoreName) - data["IsAuthenticated"] = isAuth - if isAuth { - var authUser dto.AuthorizeUser - u := r.session.GetBytes(ctx, know.StoreName) - _ = json.Unmarshal(u, &authUser) + authUser, err := r.session.GetUser(ctx, know.StoreName) + if err != nil || authUser == nil { + data["IsAuthenticated"] = false + } else { + data["IsAuthenticated"] = true data["AuthorizeMenus"] = r.getCurrentPathButtons(ctx, authUser.RoleID, req.URL.Path) } + token := nosurf.Token(req) data["CsrfToken"] = token data["CsrfTokenField"] = template.HTML(fmt.Sprintf(``, token)) diff --git a/internal/pkg/render/render.go b/internal/pkg/render/render.go index 5dd0e0f..f10c5f1 100644 --- a/internal/pkg/render/render.go +++ b/internal/pkg/render/render.go @@ -28,12 +28,12 @@ type jsonRender interface { type render struct { templateConfig *TemplateConfig templates map[string]*template.Template - session session.Session + session session.Manager menuService v1.MenuService } -func New(session session.Session, menuService v1.MenuService) (Render, error) { +func New(session session.Manager, menuService v1.MenuService) (Render, error) { r := &render{ templateConfig: &TemplateConfig{ Root: ".", diff --git a/internal/pkg/session/session.go b/internal/pkg/session/session.go index d5ecb60..12b645f 100644 --- a/internal/pkg/session/session.go +++ b/internal/pkg/session/session.go @@ -2,9 +2,12 @@ package session import ( "context" + "encoding/json" + "errors" "net/http" "time" + "management/internal/erpserver/model/dto" "management/internal/pkg/config" "github.com/alexedwards/scs/postgresstore" @@ -12,20 +15,22 @@ import ( "gorm.io/gorm" ) -type Session interface { - Destroy(ctx context.Context) error - LoadAndSave(next http.Handler) http.Handler - Put(ctx context.Context, key string, val any) - GetBytes(ctx context.Context, key string) []byte - Exists(ctx context.Context, key string) bool +var ErrNoSession = errors.New("session user not found") + +// Manager 抽象核心会话操作 +type Manager interface { + Load(next http.Handler) http.Handler + GetUser(ctx context.Context, key string) (*dto.AuthorizeUser, error) + PutUser(ctx context.Context, key string, user *dto.AuthorizeUser) error RenewToken(ctx context.Context) error + Destroy(ctx context.Context) error } -type session struct { - sessionManager *scs.SessionManager +type SCSSession struct { + manager *scs.SessionManager } -func New(db *gorm.DB, config *config.Config) Session { +func NewSCSManager(db *gorm.DB, config *config.Config) (Manager, error) { sessionManager := scs.New() sessionManager.Lifetime = 24 * time.Hour sessionManager.IdleTimeout = 2 * time.Hour @@ -35,7 +40,11 @@ func New(db *gorm.DB, config *config.Config) Session { sessionManager.Cookie.SameSite = http.SameSiteStrictMode sessionManager.Cookie.Secure = config.App.Prod - sqlDB, _ := db.DB() + sqlDB, err := db.DB() + if err != nil { + return nil, err + } + // postgres // github.com/alexedwards/scs/postgresstore sessionManager.Store = postgresstore.New(sqlDB) @@ -44,32 +53,39 @@ func New(db *gorm.DB, config *config.Config) Session { // sessionManager.Store = pgxstore.New(pool) // redis // sessionManager.Store = newRedisStore() + return &SCSSession{manager: sessionManager}, nil +} - return &session{ - sessionManager: sessionManager, +func (s *SCSSession) Load(next http.Handler) http.Handler { + return s.manager.LoadAndSave(next) +} + +func (s *SCSSession) GetUser(ctx context.Context, key string) (*dto.AuthorizeUser, error) { + data, ok := s.manager.Get(ctx, key).([]byte) + if !ok || len(data) == 0 { + return nil, ErrNoSession } + + var user dto.AuthorizeUser + if err := json.Unmarshal(data, &user); err != nil { + return nil, err + } + return &user, nil } -func (s *session) Destroy(ctx context.Context) error { - return s.sessionManager.Destroy(ctx) +func (s *SCSSession) PutUser(ctx context.Context, key string, user *dto.AuthorizeUser) error { + data, err := json.Marshal(user) + if err != nil { + return err + } + s.manager.Put(ctx, key, data) + return nil } -func (s *session) LoadAndSave(next http.Handler) http.Handler { - return s.sessionManager.LoadAndSave(next) +func (s *SCSSession) RenewToken(ctx context.Context) error { + return s.manager.RenewToken(ctx) } -func (s *session) Put(ctx context.Context, key string, val any) { - s.sessionManager.Put(ctx, key, val) -} - -func (s *session) GetBytes(ctx context.Context, key string) []byte { - return s.sessionManager.GetBytes(ctx, key) -} - -func (s *session) Exists(ctx context.Context, key string) bool { - return s.sessionManager.Exists(ctx, key) -} - -func (s *session) RenewToken(ctx context.Context) error { - return s.sessionManager.RenewToken(ctx) +func (s *SCSSession) Destroy(ctx context.Context) error { + return s.manager.Destroy(ctx) }