diff --git a/cmd/erp.go b/cmd/erp.go index 7ed2b28..e5a9640 100644 --- a/cmd/erp.go +++ b/cmd/erp.go @@ -10,6 +10,7 @@ import ( "management/internal/erpserver" "management/internal/erpserver/biz" "management/internal/erpserver/handler" + "management/internal/pkg/binding" "management/internal/pkg/logger" "management/internal/pkg/middleware" "management/internal/pkg/redis" @@ -65,6 +66,9 @@ func runErp(ctx context.Context) error { rander, err := tpl.New(session, biz.SystemV1().MenuBiz()) checkError(err) + err = binding.SetValidatorTrans("zh") + checkError(err) + handler := handler.NewHandler(conf, rander, redis, session, biz, middleware) address := fmt.Sprintf("%s:%d", conf.App.Host, conf.App.Port) diff --git a/internal/db/query/sys_user.sql b/internal/db/query/sys_user.sql index b308a62..635da08 100644 --- a/internal/db/query/sys_user.sql +++ b/internal/db/query/sys_user.sql @@ -56,11 +56,22 @@ WHERE (NOT @is_status::Boolean OR status = @status) AND (@email::text = '' OR email ILIKE '%' || @email || '%'); -- name: ListSysUserCondition :many -SELECT id, uuid, email, username, avatar, gender, department_id, role_id, status, change_password_at, created_at, updated_at, - (SELECT name FROM sys_department WHERE ID = sys_user.department_id) AS department_name, - (SELECT display_name - FROM sys_role - WHERE id = sys_user.role_id) AS role_name +SELECT id, + uuid, + email, + username, + avatar, + gender, + department_id, + role_id, + status, + change_password_at, + created_at, + updated_at, + COALESCE((SELECT name FROM sys_department WHERE ID = sys_user.department_id), '') AS department_name, + COALESCE((SELECT display_name + FROM sys_role + WHERE id = sys_user.role_id), '') AS role_name FROM sys_user WHERE (NOT @is_status::Boolean OR sys_user.status = @status) AND (NOT @is_id::Boolean OR sys_user.id = @id) diff --git a/internal/db/sqlc/sys_user.sql.go b/internal/db/sqlc/sys_user.sql.go index df92f51..a4826c9 100644 --- a/internal/db/sqlc/sys_user.sql.go +++ b/internal/db/sqlc/sys_user.sql.go @@ -276,11 +276,22 @@ func (q *Queries) ListSysUserByIds(ctx context.Context, dollar_1 []int32) ([]*Sy } const listSysUserCondition = `-- name: ListSysUserCondition :many -SELECT id, uuid, email, username, avatar, gender, department_id, role_id, status, change_password_at, created_at, updated_at, - (SELECT name FROM sys_department WHERE ID = sys_user.department_id) AS department_name, - (SELECT display_name - FROM sys_role - WHERE id = sys_user.role_id) AS role_name +SELECT id, + uuid, + email, + username, + avatar, + gender, + department_id, + role_id, + status, + change_password_at, + created_at, + updated_at, + COALESCE((SELECT name FROM sys_department WHERE ID = sys_user.department_id), '') AS department_name, + COALESCE((SELECT display_name + FROM sys_role + WHERE id = sys_user.role_id), '') AS role_name FROM sys_user WHERE (NOT $1::Boolean OR sys_user.status = $2) AND (NOT $3::Boolean OR sys_user.id = $4) @@ -303,20 +314,20 @@ type ListSysUserConditionParams struct { } type ListSysUserConditionRow struct { - ID int32 `json:"id"` - Uuid uuid.UUID `json:"uuid"` - Email string `json:"email"` - Username string `json:"username"` - Avatar string `json:"avatar"` - Gender int32 `json:"gender"` - DepartmentID int32 `json:"department_id"` - RoleID int32 `json:"role_id"` - Status int32 `json:"status"` - ChangePasswordAt time.Time `json:"change_password_at"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - DepartmentName string `json:"department_name"` - RoleName string `json:"role_name"` + ID int32 `json:"id"` + Uuid uuid.UUID `json:"uuid"` + Email string `json:"email"` + Username string `json:"username"` + Avatar string `json:"avatar"` + Gender int32 `json:"gender"` + DepartmentID int32 `json:"department_id"` + RoleID int32 `json:"role_id"` + Status int32 `json:"status"` + ChangePasswordAt time.Time `json:"change_password_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DepartmentName interface{} `json:"department_name"` + RoleName interface{} `json:"role_name"` } func (q *Queries) ListSysUserCondition(ctx context.Context, arg *ListSysUserConditionParams) ([]*ListSysUserConditionRow, error) { diff --git a/internal/erpserver/biz/v1/system/department.go b/internal/erpserver/biz/v1/system/department.go index b17836f..04e139c 100644 --- a/internal/erpserver/biz/v1/system/department.go +++ b/internal/erpserver/biz/v1/system/department.go @@ -3,19 +3,23 @@ package system import ( "context" "encoding/json" + "errors" + "fmt" "strconv" "time" "management/internal/db/model/dto" db "management/internal/db/sqlc" + "management/internal/erpserver/model/form" "management/internal/erpserver/model/view" + "management/internal/pkg/convertor" "management/internal/pkg/know" "management/internal/pkg/redis" ) type DepartmentBiz interface { - Create(ctx context.Context, arg *db.CreateSysDepartmentParams) (*db.SysDepartment, error) - Update(ctx context.Context, arg *db.UpdateSysDepartmentParams) (*db.SysDepartment, error) + Create(ctx context.Context, req *form.Department) error + Update(ctx context.Context, req *form.Department) error All(ctx context.Context) ([]*db.SysDepartment, error) List(ctx context.Context, q dto.SearchDto) ([]*db.SysDepartment, int64, error) Get(ctx context.Context, id int32) (*db.SysDepartment, error) @@ -96,12 +100,79 @@ func (b *departmentBiz) Get(ctx context.Context, id int32) (*db.SysDepartment, e return b.store.GetSysDepartment(ctx, id) } -func (b *departmentBiz) Create(ctx context.Context, arg *db.CreateSysDepartmentParams) (*db.SysDepartment, error) { - return b.store.CreateSysDepartment(ctx, arg) +func (b *departmentBiz) Create(ctx context.Context, req *form.Department) error { + parent := &db.SysDepartment{ + ID: 0, + ParentID: 0, + ParentPath: ",0,", + } + if *req.ParentID > 0 { + var err error + parent, err = b.store.GetSysDepartment(ctx, *req.ParentID) + if err != nil { + return errors.New("父级节点错误") + } + } + + var order int32 = 6666 + if *req.Sort > 0 { + order = *req.Sort + } + + arg := &db.CreateSysDepartmentParams{ + Name: req.Name, + ParentID: parent.ID, + ParentPath: convertor.HandleParentPath(fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID)), + Status: *req.Status, + Sort: order, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + _, err := b.store.CreateSysDepartment(ctx, arg) + if err != nil { + if db.IsUniqueViolation(err) { + return errors.New("部门已存在") + } + return err + } + return nil } -func (b *departmentBiz) Update(ctx context.Context, arg *db.UpdateSysDepartmentParams) (*db.SysDepartment, error) { - return b.store.UpdateSysDepartment(ctx, arg) +func (b *departmentBiz) Update(ctx context.Context, req *form.Department) error { + parent := &db.SysDepartment{ + ID: 0, + ParentID: 0, + ParentPath: ",0,", + } + if *req.ParentID > 0 { + var err error + parent, err = b.store.GetSysDepartment(ctx, *req.ParentID) + if err != nil { + return errors.New("父级节点错误") + } + } + + depart, err := b.store.GetSysDepartment(ctx, *req.ID) + if err != nil { + return err + } + + var order int32 = 6666 + if *req.Sort > 0 { + order = *req.Sort + } + + arg := &db.UpdateSysDepartmentParams{ + ID: depart.ID, + Name: req.Name, + ParentID: parent.ID, + ParentPath: convertor.HandleParentPath(fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID)), + Status: *req.Status, + Sort: order, + UpdatedAt: time.Now(), + } + _, err = b.store.UpdateSysDepartment(ctx, arg) + return err } func (b *departmentBiz) Refresh(ctx context.Context) ([]*db.SysDepartment, error) { diff --git a/internal/erpserver/biz/v1/system/role.go b/internal/erpserver/biz/v1/system/role.go index 5c673d2..a085702 100644 --- a/internal/erpserver/biz/v1/system/role.go +++ b/internal/erpserver/biz/v1/system/role.go @@ -3,19 +3,24 @@ package system import ( "context" "encoding/json" + "errors" + "fmt" "strconv" "time" "management/internal/db/model/dto" db "management/internal/db/sqlc" + "management/internal/erpserver/model/form" "management/internal/erpserver/model/view" + "management/internal/pkg/convertor" "management/internal/pkg/know" "management/internal/pkg/redis" ) type RoleBiz interface { - Create(ctx context.Context, arg *db.CreateSysRoleParams) (*db.SysRole, error) - Update(ctx context.Context, arg *db.UpdateSysRoleParams) (*db.SysRole, error) + Create(ctx context.Context, req *form.Role) error + Update(ctx context.Context, req *form.Role) error + CreateOrUpdate(ctx context.Context, req *form.Role) error All(ctx context.Context) ([]*db.SysRole, error) List(ctx context.Context, q dto.SearchDto) ([]*db.SysRole, int64, error) Get(ctx context.Context, id int32) (*db.SysRole, error) @@ -44,12 +49,140 @@ func NewRole(store db.Store, redis redis.IRedis) *roleBiz { } } -func (b *roleBiz) Create(ctx context.Context, arg *db.CreateSysRoleParams) (*db.SysRole, error) { - return b.store.CreateSysRole(ctx, arg) +func (b *roleBiz) Create(ctx context.Context, req *form.Role) error { + parent := &db.SysRole{ + ID: 0, + ParentID: 0, + ParentPath: ",0,", + } + if *req.ParentID > 0 { + var err error + parent, err = b.store.GetSysRole(ctx, *req.ParentID) + if err != nil { + return errors.New("父级节点错误") + } + } + + var order int32 = 6666 + if *req.Sort > 0 { + order = *req.Sort + } + + arg := &db.CreateSysRoleParams{ + Name: req.Name, + DisplayName: req.DisplayName, + Vip: false, + ParentID: parent.ID, + ParentPath: convertor.HandleParentPath(fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID)), + Status: *req.Status, + Sort: order, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + _, err := b.store.CreateSysRole(ctx, arg) + if err != nil { + if db.IsUniqueViolation(err) { + return errors.New("角色名称已存在") + } + return err + } + return nil } -func (b *roleBiz) Update(ctx context.Context, arg *db.UpdateSysRoleParams) (*db.SysRole, error) { - return b.store.UpdateSysRole(ctx, arg) +func (b *roleBiz) Update(ctx context.Context, req *form.Role) error { + parent := &db.SysRole{ + ID: 0, + ParentID: 0, + ParentPath: ",0,", + } + if *req.ParentID > 0 { + var err error + parent, err = b.store.GetSysRole(ctx, *req.ParentID) + if err != nil { + return errors.New("父级节点错误") + } + } + + role, err := b.store.GetSysRole(ctx, *req.ID) + if err != nil { + return err + } + + var order int32 = 6666 + if *req.Sort > 0 { + order = *req.Sort + } + + arg := &db.UpdateSysRoleParams{ + ID: role.ID, + DisplayName: req.DisplayName, + Status: *req.Status, + ParentID: parent.ID, + ParentPath: convertor.HandleParentPath(fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID)), + Sort: order, + UpdatedAt: time.Now(), + } + _, err = b.store.UpdateSysRole(ctx, arg) + return err +} + +func (b *roleBiz) CreateOrUpdate(ctx context.Context, req *form.Role) error { + parent := &db.SysRole{ + ID: 0, + ParentID: 0, + ParentPath: ",0,", + } + if *req.ParentID > 0 { + var err error + parent, err = b.store.GetSysRole(ctx, *req.ParentID) + if err != nil { + return errors.New("父级节点错误") + } + } + + var order int32 = 6666 + if *req.Sort > 0 { + order = *req.Sort + } + + if *req.ID > 0 { + role, err := b.store.GetSysRole(ctx, *req.ID) + if err != nil { + return err + } + + arg := &db.UpdateSysRoleParams{ + ID: role.ID, + DisplayName: req.DisplayName, + Status: *req.Status, + ParentID: parent.ID, + ParentPath: convertor.HandleParentPath(fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID)), + Sort: order, + UpdatedAt: time.Now(), + } + _, err = b.store.UpdateSysRole(ctx, arg) + return err + } else { + arg := &db.CreateSysRoleParams{ + Name: req.Name, + DisplayName: req.DisplayName, + Vip: false, + ParentID: parent.ID, + ParentPath: convertor.HandleParentPath(fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID)), + Status: *req.Status, + Sort: order, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + _, err := b.store.CreateSysRole(ctx, arg) + if err != nil { + if db.IsUniqueViolation(err) { + return errors.New("角色名称已存在") + } + return err + } + return nil + } } func (b *roleBiz) All(ctx context.Context) ([]*db.SysRole, error) { diff --git a/internal/erpserver/biz/v1/system/user.go b/internal/erpserver/biz/v1/system/user.go index 466c3d9..eaddce6 100644 --- a/internal/erpserver/biz/v1/system/user.go +++ b/internal/erpserver/biz/v1/system/user.go @@ -9,17 +9,20 @@ import ( "management/internal/db/model/dto" db "management/internal/db/sqlc" - "management/internal/erpserver/model/req" + "management/internal/erpserver/model/form" "management/internal/erpserver/model/view" "management/internal/pkg/crypto" "management/internal/pkg/know" + "management/internal/pkg/rand" "management/internal/pkg/session" + + "github.com/google/uuid" ) // UserBiz 定义处理用户请求所需的方法. type UserBiz interface { - Create(ctx context.Context, req *db.CreateSysUserParams) (*db.SysUser, error) - Update(ctx context.Context, req *db.UpdateSysUserParams) (*db.SysUser, error) + Create(ctx context.Context, req *form.User) error + Update(ctx context.Context, req *form.User) error All(ctx context.Context) ([]*db.SysUser, error) List(ctx context.Context, q dto.SearchDto) ([]*db.ListSysUserConditionRow, int64, error) Get(ctx context.Context, id int32) (*db.SysUser, error) @@ -31,7 +34,7 @@ type UserBiz interface { // UserExpansion 定义用户操作的扩展方法. type UserExpansion interface { - Login(ctx context.Context, req *req.Login) error + Login(ctx context.Context, req *form.Login) error } // userBiz 是 UserBiz 接口的实现. @@ -50,12 +53,75 @@ func NewUser(store db.Store, session session.ISession) *userBiz { } } -func (b *userBiz) Create(ctx context.Context, req *db.CreateSysUserParams) (*db.SysUser, error) { - return b.store.CreateSysUser(ctx, req) +func (b *userBiz) Create(ctx context.Context, req *form.User) error { + salt, err := rand.String(10) + if err != nil { + return err + } + + hashedPassword, err := crypto.BcryptHashPassword(req.Password + salt) + if err != nil { + return err + } + + initTime, err := time.ParseInLocation(time.DateTime, "0001-01-01 00:00:00", time.Local) + if err != nil { + return err + } + + arg := &db.CreateSysUserParams{ + Uuid: uuid.Must(uuid.NewV7()), + Email: req.Email, + Username: req.Username, + HashedPassword: hashedPassword, + Salt: salt, + Avatar: req.Avatar, + Gender: req.Gender, + DepartmentID: req.DepartmentID, + RoleID: req.RoleID, + Status: *req.Status, + ChangePasswordAt: initTime, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + _, err = b.store.CreateSysUser(ctx, arg) + if err != nil { + if db.IsUniqueViolation(err) { + return errors.New("用户已经存在") + } + return err + } + return nil } -func (b *userBiz) Update(ctx context.Context, req *db.UpdateSysUserParams) (*db.SysUser, error) { - return b.store.UpdateSysUser(ctx, req) +func (b *userBiz) Update(ctx context.Context, req *form.User) error { + user, err := b.store.GetSysUser(ctx, *req.ID) + if err != nil { + return err + } + + arg := &db.UpdateSysUserParams{ + ID: user.ID, + Username: req.Username, + HashedPassword: user.HashedPassword, + Avatar: req.Avatar, + Gender: req.Gender, + DepartmentID: req.DepartmentID, + RoleID: req.RoleID, + Status: *req.Status, + ChangePasswordAt: user.ChangePasswordAt, + UpdatedAt: time.Now(), + } + if req.ChangePassword == "on" { + hashedPassword, err := crypto.BcryptHashPassword(req.Password + user.Salt) + if err != nil { + return err + } + arg.HashedPassword = hashedPassword + arg.ChangePasswordAt = time.Now() + } + _, err = b.store.UpdateSysUser(ctx, arg) + return err } func (b *userBiz) All(ctx context.Context) ([]*db.SysUser, error) { @@ -112,7 +178,7 @@ func (b *userBiz) XmSelect(ctx context.Context) ([]*view.XmSelect, error) { return res, nil } -func (b *userBiz) Login(ctx context.Context, req *req.Login) error { +func (b *userBiz) Login(ctx context.Context, req *form.Login) error { log := &db.CreateSysUserLoginLogParams{ CreatedAt: time.Now(), Email: req.Email, diff --git a/internal/erpserver/handler/system/department.go b/internal/erpserver/handler/system/department.go index 6444d55..35021f5 100644 --- a/internal/erpserver/handler/system/department.go +++ b/internal/erpserver/handler/system/department.go @@ -1,13 +1,13 @@ package system import ( - "fmt" "net/http" - "time" "management/internal/db/model/dto" db "management/internal/db/sqlc" "management/internal/erpserver/biz" + "management/internal/erpserver/model/form" + "management/internal/pkg/binding" "management/internal/pkg/convertor" "management/internal/pkg/tpl" ) @@ -95,61 +95,22 @@ func (h *departmentHandler) Edit(w http.ResponseWriter, r *http.Request) { } func (h *departmentHandler) Save(w http.ResponseWriter, r *http.Request) { - id := convertor.ConvertInt[int32](r.PostFormValue("ID"), 0) - ParentID := convertor.ConvertInt[int32](r.PostFormValue("ParentID"), 0) - name := r.PostFormValue("Name") - sort := convertor.ConvertInt[int32](r.PostFormValue("Sort"), 6666) - status := convertor.ConvertInt[int32](r.PostFormValue("Status"), 9999) - - ctx := r.Context() - var parent *db.SysDepartment - if ParentID > 0 { - var err error - parent, err = h.biz.SystemV1().DepartmentBiz().Get(ctx, ParentID) - if err != nil { - h.render.JSONERR(w, "父级节点错误") - return - } + var req form.Department + if err := binding.Form.Bind(r, &req); err != nil { + h.render.JSONERR(w, binding.ValidatorErrors(err)) + return } - if id == 0 { - arg := db.CreateSysDepartmentParams{ - Name: name, - ParentID: ParentID, - ParentPath: fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID), - Status: status, - Sort: sort, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - _, err := h.biz.SystemV1().DepartmentBiz().Create(ctx, &arg) + ctx := r.Context() + if *req.ID == 0 { + err := h.biz.SystemV1().DepartmentBiz().Create(ctx, &req) if err != nil { - if db.IsUniqueViolation(err) { - h.render.JSONERR(w, "部门名称已存在") - return - } h.render.JSONERR(w, err.Error()) return } - h.render.JSONOK(w, "添加成功") } else { - res, err := h.biz.SystemV1().DepartmentBiz().Get(ctx, id) - if err != nil { - h.render.JSONERR(w, err.Error()) - return - } - - arg := &db.UpdateSysDepartmentParams{ - ID: res.ID, - Name: name, - ParentID: ParentID, - ParentPath: fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID), - Status: status, - Sort: sort, - UpdatedAt: time.Now(), - } - _, err = h.biz.SystemV1().DepartmentBiz().Update(ctx, arg) + err := h.biz.SystemV1().DepartmentBiz().Update(ctx, &req) if err != nil { h.render.JSONERR(w, err.Error()) return diff --git a/internal/erpserver/handler/system/home.go b/internal/erpserver/handler/system/home.go new file mode 100644 index 0000000..388f7ab --- /dev/null +++ b/internal/erpserver/handler/system/home.go @@ -0,0 +1,11 @@ +package system + +import "net/http" + +func (h *systemHandler) Home(w http.ResponseWriter, r *http.Request) { + h.render.HTML(w, r, "home/home.tmpl", nil) +} + +func (h *systemHandler) Dashboard(w http.ResponseWriter, r *http.Request) { + h.render.HTML(w, r, "home/dashboard.tmpl", nil) +} diff --git a/internal/erpserver/handler/system/role.go b/internal/erpserver/handler/system/role.go index 93da345..aa4105a 100644 --- a/internal/erpserver/handler/system/role.go +++ b/internal/erpserver/handler/system/role.go @@ -1,14 +1,14 @@ package system import ( - "fmt" "net/http" "strings" - "time" "management/internal/db/model/dto" db "management/internal/db/sqlc" "management/internal/erpserver/biz" + "management/internal/erpserver/model/form" + "management/internal/pkg/binding" "management/internal/pkg/convertor" "management/internal/pkg/tpl" ) @@ -98,69 +98,22 @@ func (h *roleHandler) Edit(w http.ResponseWriter, r *http.Request) { } func (h *roleHandler) Save(w http.ResponseWriter, r *http.Request) { - id := convertor.ConvertInt[int32](r.PostFormValue("ID"), 0) - name := r.PostFormValue("Name") - parentID := convertor.ConvertInt[int32](r.PostFormValue("ParentID"), 0) - displayName := r.PostFormValue("DisplayName") - sort := convertor.ConvertInt[int32](r.PostFormValue("Sort"), 6666) - status := convertor.ConvertInt[int32](r.PostFormValue("Status"), 0) - - ctx := r.Context() - var parent *db.SysRole - if parentID > 0 { - var err error - parent, err = h.biz.SystemV1().RoleBiz().Get(ctx, parentID) - if err != nil { - h.render.JSONERR(w, "父级节点错误") - return - } - } else { - parent = &db.SysRole{ - ID: 0, - ParentID: 0, - ParentPath: ",0,", - } + var req form.Role + if err := binding.Form.Bind(r, &req); err != nil { + h.render.JSONERR(w, binding.ValidatorErrors(err)) + return } - if id == 0 { - arg := &db.CreateSysRoleParams{ - Name: name, - DisplayName: displayName, - Vip: false, - ParentID: parent.ID, - ParentPath: fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID), - Status: status, - Sort: sort, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - _, err := h.biz.SystemV1().RoleBiz().Create(ctx, arg) + ctx := r.Context() + if *req.ID == 0 { + err := h.biz.SystemV1().RoleBiz().Create(ctx, &req) if err != nil { - if db.IsUniqueViolation(err) { - h.render.JSONERR(w, "角色名称已存在") - return - } h.render.JSONERR(w, err.Error()) return } - h.render.JSONOK(w, "添加成功") } else { - res, err := h.biz.SystemV1().RoleBiz().Get(ctx, id) - if err != nil { - h.render.JSONERR(w, err.Error()) - return - } - arg := &db.UpdateSysRoleParams{ - ID: res.ID, - DisplayName: displayName, - Status: status, - ParentID: parent.ID, - ParentPath: fmt.Sprintf("%s,%d,", parent.ParentPath, parent.ID), - Sort: sort, - UpdatedAt: time.Now(), - } - _, err = h.biz.SystemV1().RoleBiz().Update(ctx, arg) + err := h.biz.SystemV1().RoleBiz().Update(ctx, &req) if err != nil { h.render.JSONERR(w, err.Error()) return diff --git a/internal/erpserver/handler/system/system.go b/internal/erpserver/handler/system/system.go index 9459de0..52a0c66 100644 --- a/internal/erpserver/handler/system/system.go +++ b/internal/erpserver/handler/system/system.go @@ -12,6 +12,7 @@ import ( type SystemHandler interface { Home(w http.ResponseWriter, req *http.Request) + Dashboard(w http.ResponseWriter, req *http.Request) UserHandler() UserHandler MenuHandler() MenuHandler RoleHandler() RoleHandler @@ -42,10 +43,6 @@ func NewSystemHandler(render tpl.Renderer, redis redis.IRedis, session session.I } } -func (h *systemHandler) Home(w http.ResponseWriter, r *http.Request) { - h.render.HTML(w, r, "home/home.tmpl", nil) -} - func (h *systemHandler) UserHandler() UserHandler { return NewUserHandler(h.render, h.session, h.biz, h.mi) } diff --git a/internal/erpserver/handler/system/user.go b/internal/erpserver/handler/system/user.go index 84066fc..e34b7f9 100644 --- a/internal/erpserver/handler/system/user.go +++ b/internal/erpserver/handler/system/user.go @@ -3,23 +3,19 @@ package system import ( "encoding/json" "net/http" - "strings" - "time" "management/internal/db/model/dto" db "management/internal/db/sqlc" "management/internal/erpserver/biz" - "management/internal/erpserver/model/req" + "management/internal/erpserver/model/form" + "management/internal/pkg/binding" "management/internal/pkg/convertor" - "management/internal/pkg/crypto" "management/internal/pkg/know" "management/internal/pkg/middleware" - "management/internal/pkg/rand" "management/internal/pkg/session" "management/internal/pkg/tpl" "management/internal/pkg/tpl/html" - "github.com/google/uuid" "github.com/zhang2092/browser" ) @@ -74,7 +70,7 @@ func (h *userHandler) Edit(w http.ResponseWriter, r *http.Request) { if id > 0 { ctx := r.Context() if user, err := h.biz.SystemV1().UserBiz().Get(ctx, id); err == nil { - user.HashedPassword = nil + user.HashedPassword = []byte("********") sysUser = user } } @@ -84,109 +80,36 @@ func (h *userHandler) Edit(w http.ResponseWriter, r *http.Request) { } func (h *userHandler) Save(w http.ResponseWriter, r *http.Request) { - id := convertor.ConvertInt[int32](r.PostFormValue("ID"), 0) - email := r.PostFormValue("Email") - username := r.PostFormValue("Username") - password := r.PostFormValue("Password") - changePassword := r.PostFormValue("ChangePassword") - gender := convertor.ConvertInt[int32](r.PostFormValue("Gender"), 0) - avatar := r.PostFormValue("File") - status := convertor.ConvertInt[int32](r.PostFormValue("Status"), 0) + var req form.User + if err := binding.Form.Bind(r, &req); err != nil { + h.render.JSONERR(w, binding.ValidatorErrors(err)) + return + } ctx := r.Context() - departmentID := convertor.ConvertInt[int32](r.PostFormValue("DepartmentID"), 0) - var department *db.SysDepartment - var err error - if departmentID > 0 { - department, err = h.biz.SystemV1().DepartmentBiz().Get(ctx, departmentID) - if err != nil { + if req.DepartmentID > 0 { + if _, err := h.biz.SystemV1().DepartmentBiz().Get(ctx, req.DepartmentID); err != nil { h.render.JSONERR(w, "部门数据错误") return } } - var role *db.SysRole - roleID := convertor.ConvertInt[int32](r.PostFormValue("RoleID"), 0) - if roleID > 0 { - role, err = h.biz.SystemV1().RoleBiz().Get(ctx, roleID) - if err != nil { + if req.RoleID > 0 { + if _, err := h.biz.SystemV1().RoleBiz().Get(ctx, req.RoleID); err != nil { h.render.JSONERR(w, "角色数据错误") return } } - if id == 0 { - salt, err := rand.String(10) + if *req.ID == 0 { + err := h.biz.SystemV1().UserBiz().Create(ctx, &req) if err != nil { h.render.JSONERR(w, err.Error()) return } - hashedPassword, err := crypto.BcryptHashPassword(password + salt) - if err != nil { - h.render.JSONERR(w, err.Error()) - return - } - - initTime, err := time.ParseInLocation(time.DateTime, "0001-01-01 00:00:00", time.Local) - if err != nil { - h.render.JSONERR(w, err.Error()) - return - } - arg := &db.CreateSysUserParams{ - Uuid: uuid.Must(uuid.NewV7()), - Email: email, - Username: username, - HashedPassword: hashedPassword, - Salt: salt, - Avatar: avatar, - Gender: gender, - DepartmentID: department.ID, - RoleID: role.ID, - Status: status, - ChangePasswordAt: initTime, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - _, err = h.biz.SystemV1().UserBiz().Create(ctx, arg) - if err != nil { - if db.IsUniqueViolation(err) { - h.render.JSONERR(w, "数据已存在") - return - } - h.render.JSONERR(w, err.Error()) - return - } - h.render.JSONOK(w, "添加成功") } else { - res, err := h.biz.SystemV1().UserBiz().Get(ctx, id) - if err != nil { - h.render.JSONERR(w, err.Error()) - return - } - - arg := &db.UpdateSysUserParams{ - ID: res.ID, - Username: username, - HashedPassword: res.HashedPassword, - Avatar: avatar, - Gender: int32(gender), - DepartmentID: department.ID, - RoleID: role.ID, - Status: int32(status), - ChangePasswordAt: res.ChangePasswordAt, - UpdatedAt: time.Now(), - } - if changePassword == "on" { - hashedPassword, err := crypto.BcryptHashPassword(password + res.Salt) - if err != nil { - h.render.JSONERR(w, err.Error()) - return - } - arg.HashedPassword = hashedPassword - arg.ChangePasswordAt = time.Now() - } - _, err = h.biz.SystemV1().UserBiz().Update(ctx, arg) + err := h.biz.SystemV1().UserBiz().Update(ctx, &req) if err != nil { h.render.JSONERR(w, err.Error()) return @@ -261,7 +184,7 @@ func (h *userHandler) Login(w http.ResponseWriter, r *http.Request) { var user dto.AuthorizeUser u := h.session.GetBytes(ctx, know.StoreName) if err := json.Unmarshal(u, &user); err == nil { - // 判断租户是否一致, 一致则刷新令牌,跳转到首页 + // 判断用户是否登陆, 已经登陆则刷新令牌,跳转到首页 if err := h.session.RenewToken(ctx); err == nil { h.session.Put(ctx, know.StoreName, u) http.Redirect(w, r, "/home.html", http.StatusFound) @@ -271,30 +194,16 @@ func (h *userHandler) Login(w http.ResponseWriter, r *http.Request) { h.session.Destroy(ctx) h.render.HTML(w, r, "oauth/login.tmpl", nil) case http.MethodPost: - req := &req.Login{ - Email: strings.TrimSpace(r.PostFormValue("email")), - Password: strings.TrimSpace(r.PostFormValue("password")), - CaptchaID: strings.TrimSpace(r.PostFormValue("captcha_id")), - Captcha: strings.TrimSpace(r.PostFormValue("captcha")), - Ip: r.RemoteAddr, - Referrer: r.Header.Get("Referer"), - Url: r.URL.RequestURI(), + defer r.Body.Close() + var req form.Login + if err := binding.Form.Bind(r, &req); err != nil { + e := binding.ValidatorErrors(err) + h.render.JSONERR(w, e) + return } - if len(req.Email) == 0 { - h.render.JSON(w, tpl.Response{Success: false, Message: "请填写邮箱"}) - return - } - if len(req.Password) == 0 { - h.render.JSON(w, tpl.Response{Success: false, Message: "请填写密码"}) - return - } - if len(req.Captcha) == 0 { - h.render.JSON(w, tpl.Response{Success: false, Message: "请填写验证码"}) - return - } if !h.biz.CommonV1().CaptchaBiz().Verify(req.CaptchaID, req.Captcha, true) { - h.render.JSON(w, tpl.Response{Success: false, Message: "验证码错误"}) + h.render.JSONERR(w, "验证码错误") return } @@ -304,9 +213,12 @@ func (h *userHandler) Login(w http.ResponseWriter, r *http.Request) { return } + req.Ip = r.RemoteAddr + req.Referrer = r.Header.Get("Referer") + req.Url = r.URL.RequestURI() req.Os = br.Platform().Name() req.Browser = br.Name() - err = h.biz.SystemV1().UserBiz().Login(ctx, req) + err = h.biz.SystemV1().UserBiz().Login(ctx, &req) if err != nil { h.render.JSONERR(w, err.Error()) return diff --git a/internal/erpserver/http.go b/internal/erpserver/http.go index 3336ded..eb8c00a 100644 --- a/internal/erpserver/http.go +++ b/internal/erpserver/http.go @@ -38,7 +38,8 @@ func NewRouter(handler handler.IHandler, mw mw.IMiddleware) *chi.Mux { r.With(mw.Authorize, mw.Audit).Post("/upload/file", handler.CommonHandler().UploadHandler().File) r.With(mw.Authorize, mw.Audit).Post("/upload/mutilfile", handler.CommonHandler().UploadHandler().MutilFiles) - r.With(mw.Authorize, mw.Audit).Get("/home.html", handler.SystemHandler().Home) + r.With(mw.Authorize).Get("/home.html", handler.SystemHandler().Home) + r.With(mw.Authorize).Get("/dashboard", handler.SystemHandler().Dashboard) r.With(mw.Authorize).Get("/pear.json", handler.SystemHandler().ConfigHandler().Pear) r.Route("/system", func(r chi.Router) { diff --git a/internal/erpserver/model/form/department.go b/internal/erpserver/model/form/department.go new file mode 100644 index 0000000..af56c00 --- /dev/null +++ b/internal/erpserver/model/form/department.go @@ -0,0 +1,9 @@ +package form + +type Department struct { + ID *int32 `form:"id" binding:"required"` + Name string `form:"name" binding:"required"` + ParentID *int32 `form:"parent_id" binding:"required"` + Sort *int32 `form:"sort"` + Status *int32 `form:"status" binding:"required"` +} diff --git a/internal/erpserver/model/form/role.go b/internal/erpserver/model/form/role.go new file mode 100644 index 0000000..e8bda30 --- /dev/null +++ b/internal/erpserver/model/form/role.go @@ -0,0 +1,10 @@ +package form + +type Role struct { + ID *int32 `form:"id" binding:"required"` + Name string `form:"name" binding:"required"` + ParentID *int32 `form:"parent_id" binding:"required"` + DisplayName string `form:"display_name" binding:"required"` + Sort *int32 `form:"sort"` + Status *int32 `form:"status" binding:"required"` +} diff --git a/internal/erpserver/model/form/user.go b/internal/erpserver/model/form/user.go new file mode 100644 index 0000000..b8ae7c8 --- /dev/null +++ b/internal/erpserver/model/form/user.go @@ -0,0 +1,28 @@ +package form + +type Login struct { + Email string `form:"email" binding:"required,email"` + Password string `form:"password" binding:"required,min=6"` + Captcha string `form:"captcha" binding:"required"` + CaptchaID string `form:"captcha_id" binding:"required"` + + // 平台信息 + Os string + Ip string + Browser string + Referrer string + Url string +} + +type User struct { + ID *int32 `form:"id" binding:"required"` + Email string `form:"email" binding:"required,email"` + Username string `form:"username" binding:"required"` + Password string `form:"password" binding:"required,min=6"` + ChangePassword string `form:"change_password"` + Avatar string `form:"File"` + Gender int32 `form:"gender"` + DepartmentID int32 `form:"department_id"` + RoleID int32 `form:"role_id"` + Status *int32 `form:"status" binding:"required"` +} diff --git a/internal/erpserver/model/req/user.go b/internal/erpserver/model/req/user.go deleted file mode 100644 index d098ef5..0000000 --- a/internal/erpserver/model/req/user.go +++ /dev/null @@ -1,15 +0,0 @@ -package req - -type Login struct { - Email string `json:"email"` - Password string `json:"password"` - Captcha string `json:"captcha"` - CaptchaID string `json:"captcha_id"` - - // 平台信息 - Os string `json:"os"` - Ip string `json:"ip"` - Browser string `json:"browser"` - Referrer string `json:"referrer"` - Url string `json:"url"` -} diff --git a/internal/pkg/binding/binding.go b/internal/pkg/binding/binding.go new file mode 100644 index 0000000..a455867 --- /dev/null +++ b/internal/pkg/binding/binding.go @@ -0,0 +1,139 @@ +package binding + +import ( + "net/http" +) + +const ( + MIMEJSON = "application/json" + MIMEHTML = "text/html" + MIMEXML = "application/xml" + MIMEXML2 = "text/xml" + MIMEPlain = "text/plain" + MIMEPOSTForm = "application/x-www-form-urlencoded" + MIMEMultipartPOSTForm = "multipart/form-data" + MIMEPROTOBUF = "application/x-protobuf" + MIMEMSGPACK = "application/x-msgpack" + MIMEMSGPACK2 = "application/msgpack" + MIMEYAML = "application/x-yaml" + MIMEYAML2 = "application/yaml" + MIMETOML = "application/toml" +) + +type Binding interface { + Name() string + Bind(*http.Request, any) error +} + +type BindingBody interface { + Binding + BindBody([]byte, any) error +} + +type BindingUri interface { + Name() string + BindUri(map[string][]string, any) error +} + +// StructValidator is the minimal interface which needs to be implemented in +// order for it to be used as the validator engine for ensuring the correctness +// of the request. Gin provides a default implementation for this using +// https://github.com/go-playground/validator/tree/v10.6.1. +type StructValidator interface { + // ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right. + // If the received type is a slice|array, the validation should be performed travel on every element. + // If the received type is not a struct or slice|array, any validation should be skipped and nil must be returned. + // If the received type is a struct or pointer to a struct, the validation should be performed. + // If the struct is not valid or the validation itself fails, a descriptive error should be returned. + // Otherwise nil must be returned. + ValidateStruct(any) error + + // Engine returns the underlying validator engine which powers the + // StructValidator implementation. + Engine() any +} + +// Validator is the default validator which implements the StructValidator +// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1 +// under the hood. +var Validator StructValidator = &defaultValidator{} + +var ( + JSON BindingBody = jsonBinding{} + // XML BindingBody = xmlBinding{} + Form Binding = formBinding{} + Query Binding = queryBinding{} + FormPost Binding = formPostBinding{} + FormMultipart Binding = formMultipartBinding{} + // ProtoBuf BindingBody = protobufBinding{} + // MsgPack BindingBody = msgpackBinding{} + // YAML BindingBody = yamlBinding{} + // Uri BindingUri = uriBinding{} + // Header Binding = headerBinding{} + // Plain BindingBody = plainBinding{} + // TOML BindingBody = tomlBinding{} +) + +// Default returns the appropriate Binding instance based on the HTTP method +// and the content type. +func Default(method, contentType string) Binding { + if method == http.MethodGet { + return Form + } + + switch contentType { + case MIMEJSON: + return JSON + // case MIMEXML, MIMEXML2: + // return XML + // case MIMEPROTOBUF: + // return ProtoBuf + // case MIMEMSGPACK, MIMEMSGPACK2: + // return MsgPack + // case MIMEYAML, MIMEYAML2: + // return YAML + // case MIMETOML: + // return TOML + case MIMEMultipartPOSTForm: + return FormMultipart + default: // case MIMEPOSTForm: + return Form + } +} + +func validate(obj any) error { + if Validator == nil { + return nil + } + return Validator.ValidateStruct(obj) +} + +// ShouldBind checks the Method and Content-Type to select a binding engine automatically, +// Depending on the "Content-Type" header different bindings are used, for example: +// +// "application/json" --> JSON binding +// "application/xml" --> XML binding +// +// It parses the request's body as JSON if Content-Type == "application/json" using JSON or XML as a JSON input. +// It decodes the json payload into the struct specified as a pointer. +// Like c.Bind() but this method does not set the response status code to 400 or abort if input is not valid. +func ShouldBind(r *http.Request, obj any) error { + b := Default(r.Method, r.Header.Get("Content-Type")) + return ShouldBindWith(r, obj, b) +} + +// ShouldBindJSON is a shortcut for c.ShouldBindWith(obj, JSON). +func ShouldBindJSON(r *http.Request, obj any) error { + return ShouldBindWith(r, obj, JSON) +} + +// ShouldBindWith binds the passed struct pointer using the specified binding engine. +// See the binding package. +func ShouldBindWith(r *http.Request, obj any, b Binding) error { + return b.Bind(r, obj) +} + +// ShouldBindQuery is a shortcut for c.ShouldBindWith(obj, Query). +func ShouldBindQuery(r *http.Request, obj any) error { + return ShouldBindWith(r, obj, Query) +} diff --git a/internal/pkg/binding/byteconv.go b/internal/pkg/binding/byteconv.go new file mode 100644 index 0000000..1ecdd5e --- /dev/null +++ b/internal/pkg/binding/byteconv.go @@ -0,0 +1,15 @@ +package binding + +import "unsafe" + +// stringToBytes converts string to byte slice without a memory allocation. +// For more details, see https://github.com/golang/go/issues/53003#issuecomment-1140276077. +func stringToBytes(s string) []byte { + return unsafe.Slice(unsafe.StringData(s), len(s)) +} + +// bytesToString converts byte slice to string without a memory allocation. +// For more details, see https://github.com/golang/go/issues/53003#issuecomment-1140276077. +func bytesToString(b []byte) string { + return unsafe.String(unsafe.SliceData(b), len(b)) +} diff --git a/internal/pkg/binding/default_validator.go b/internal/pkg/binding/default_validator.go new file mode 100644 index 0000000..8da9a3f --- /dev/null +++ b/internal/pkg/binding/default_validator.go @@ -0,0 +1,91 @@ +package binding + +import ( + "reflect" + "strconv" + "strings" + "sync" + + "github.com/go-playground/validator/v10" +) + +type defaultValidator struct { + once sync.Once + validate *validator.Validate +} + +type SliceValidationError []error + +// Error concatenates all error elements in SliceValidationError into a single string separated by \n. +func (err SliceValidationError) Error() string { + if len(err) == 0 { + return "" + } + + var b strings.Builder + for i := 0; i < len(err); i++ { + if err[i] != nil { + if b.Len() > 0 { + b.WriteString("\n") + } + b.WriteString("[" + strconv.Itoa(i) + "]: " + err[i].Error()) + } + } + return b.String() +} + +var _ StructValidator = (*defaultValidator)(nil) + +// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. +func (v *defaultValidator) ValidateStruct(obj any) error { + if obj == nil { + return nil + } + + value := reflect.ValueOf(obj) + switch value.Kind() { + case reflect.Ptr: + if value.Elem().Kind() != reflect.Struct { + return v.ValidateStruct(value.Elem().Interface()) + } + return v.validateStruct(obj) + case reflect.Struct: + return v.validateStruct(obj) + case reflect.Slice, reflect.Array: + count := value.Len() + validateRet := make(SliceValidationError, 0) + for i := 0; i < count; i++ { + if err := v.ValidateStruct(value.Index(i).Interface()); err != nil { + validateRet = append(validateRet, err) + } + } + if len(validateRet) == 0 { + return nil + } + return validateRet + default: + return nil + } +} + +// validateStruct receives struct type +func (v *defaultValidator) validateStruct(obj any) error { + v.lazyinit() + return v.validate.Struct(obj) +} + +// Engine returns the underlying validator engine which powers the default +// Validator instance. This is useful if you want to register custom validations +// or struct level validations. See validator GoDoc for more info - +// https://pkg.go.dev/github.com/go-playground/validator/v10 +func (v *defaultValidator) Engine() any { + v.lazyinit() + return v.validate +} + +func (v *defaultValidator) lazyinit() { + v.once.Do(func() { + v.validate = validator.New() + v.validate.SetTagName("binding") + }) +} diff --git a/internal/pkg/binding/form.go b/internal/pkg/binding/form.go new file mode 100644 index 0000000..5ee9811 --- /dev/null +++ b/internal/pkg/binding/form.go @@ -0,0 +1,60 @@ +package binding + +import ( + "errors" + "net/http" +) + +const defaultMemory = 32 << 20 + +type ( + formBinding struct{} + formPostBinding struct{} + formMultipartBinding struct{} +) + +func (formBinding) Name() string { + return "form" +} + +func (formBinding) Bind(req *http.Request, obj any) error { + if err := req.ParseForm(); err != nil { + return err + } + if err := req.ParseMultipartForm(defaultMemory); err != nil && !errors.Is(err, http.ErrNotMultipart) { + return err + } + if err := mapForm(obj, req.Form); err != nil { + return err + } + return validate(obj) +} + +func (formPostBinding) Name() string { + return "form-urlencoded" +} + +func (formPostBinding) Bind(req *http.Request, obj any) error { + if err := req.ParseForm(); err != nil { + return err + } + if err := mapForm(obj, req.PostForm); err != nil { + return err + } + return validate(obj) +} + +func (formMultipartBinding) Name() string { + return "multipart/form-data" +} + +func (formMultipartBinding) Bind(req *http.Request, obj any) error { + if err := req.ParseMultipartForm(defaultMemory); err != nil { + return err + } + if err := mappingByPtr(obj, (*multipartRequest)(req), "form"); err != nil { + return err + } + + return validate(obj) +} diff --git a/internal/pkg/binding/form_mapping.go b/internal/pkg/binding/form_mapping.go new file mode 100644 index 0000000..eef249f --- /dev/null +++ b/internal/pkg/binding/form_mapping.go @@ -0,0 +1,473 @@ +package binding + +import ( + "encoding/json" + "errors" + "fmt" + "mime/multipart" + "reflect" + "strconv" + "strings" + "time" +) + +var ( + errUnknownType = errors.New("unknown type") + + // ErrConvertMapStringSlice can not convert to map[string][]string + ErrConvertMapStringSlice = errors.New("can not convert to map slices of strings") + + // ErrConvertToMapString can not convert to map[string]string + ErrConvertToMapString = errors.New("can not convert to map of strings") +) + +func mapURI(ptr any, m map[string][]string) error { + return mapFormByTag(ptr, m, "uri") +} + +func mapForm(ptr any, form map[string][]string) error { + return mapFormByTag(ptr, form, "form") +} + +func MapFormWithTag(ptr any, form map[string][]string, tag string) error { + return mapFormByTag(ptr, form, tag) +} + +var emptyField = reflect.StructField{} + +func mapFormByTag(ptr any, form map[string][]string, tag string) error { + // Check if ptr is a map + ptrVal := reflect.ValueOf(ptr) + var pointed any + if ptrVal.Kind() == reflect.Ptr { + ptrVal = ptrVal.Elem() + pointed = ptrVal.Interface() + } + if ptrVal.Kind() == reflect.Map && + ptrVal.Type().Key().Kind() == reflect.String { + if pointed != nil { + ptr = pointed + } + return setFormMap(ptr, form) + } + + return mappingByPtr(ptr, formSource(form), tag) +} + +// setter tries to set value on a walking by fields of a struct +type setter interface { + TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (isSet bool, err error) +} + +type formSource map[string][]string + +var _ setter = formSource(nil) + +// TrySet tries to set a value by request's form source (like map[string][]string) +func (form formSource) TrySet(value reflect.Value, field reflect.StructField, tagValue string, opt setOptions) (isSet bool, err error) { + return setByForm(value, field, form, tagValue, opt) +} + +func mappingByPtr(ptr any, setter setter, tag string) error { + _, err := mapping(reflect.ValueOf(ptr), emptyField, setter, tag) + return err +} + +func mapping(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) { + if field.Tag.Get(tag) == "-" { // just ignoring this field + return false, nil + } + + vKind := value.Kind() + + if vKind == reflect.Ptr { + var isNew bool + vPtr := value + if value.IsNil() { + isNew = true + vPtr = reflect.New(value.Type().Elem()) + } + isSet, err := mapping(vPtr.Elem(), field, setter, tag) + if err != nil { + return false, err + } + if isNew && isSet { + value.Set(vPtr) + } + return isSet, nil + } + + if vKind != reflect.Struct || !field.Anonymous { + ok, err := tryToSetValue(value, field, setter, tag) + if err != nil { + return false, err + } + if ok { + return true, nil + } + } + + if vKind == reflect.Struct { + tValue := value.Type() + + var isSet bool + for i := 0; i < value.NumField(); i++ { + sf := tValue.Field(i) + if sf.PkgPath != "" && !sf.Anonymous { // unexported + continue + } + ok, err := mapping(value.Field(i), sf, setter, tag) + if err != nil { + return false, err + } + isSet = isSet || ok + } + return isSet, nil + } + return false, nil +} + +type setOptions struct { + isDefaultExists bool + defaultValue string +} + +func tryToSetValue(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) { + var tagValue string + var setOpt setOptions + + tagValue = field.Tag.Get(tag) + tagValue, opts := head(tagValue, ",") + + if tagValue == "" { // default value is FieldName + tagValue = field.Name + } + if tagValue == "" { // when field is "emptyField" variable + return false, nil + } + + var opt string + for len(opts) > 0 { + opt, opts = head(opts, ",") + + if k, v := head(opt, "="); k == "default" { + setOpt.isDefaultExists = true + setOpt.defaultValue = v + } + } + + return setter.TrySet(value, field, tagValue, setOpt) +} + +// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. +type BindUnmarshaler interface { + // UnmarshalParam decodes and assigns a value from an form or query param. + UnmarshalParam(param string) error +} + +// trySetCustom tries to set a custom type value +// If the value implements the BindUnmarshaler interface, it will be used to set the value, we will return `true` +// to skip the default value setting. +func trySetCustom(val string, value reflect.Value) (isSet bool, err error) { + switch v := value.Addr().Interface().(type) { + case BindUnmarshaler: + return true, v.UnmarshalParam(val) + } + return false, nil +} + +func trySplit(vs []string, field reflect.StructField) (newVs []string, err error) { + cfTag := field.Tag.Get("collection_format") + if cfTag == "" || cfTag == "multi" { + return vs, nil + } + + var sep string + switch cfTag { + case "csv": + sep = "," + case "ssv": + sep = " " + case "tsv": + sep = "\t" + case "pipes": + sep = "|" + default: + return vs, fmt.Errorf("%s is not supported in the collection_format. (csv, ssv, pipes)", cfTag) + } + + totalLength := 0 + for _, v := range vs { + totalLength += strings.Count(v, sep) + 1 + } + newVs = make([]string, 0, totalLength) + for _, v := range vs { + newVs = append(newVs, strings.Split(v, sep)...) + } + + return newVs, nil +} + +func setByForm(value reflect.Value, field reflect.StructField, form map[string][]string, tagValue string, opt setOptions) (isSet bool, err error) { + vs, ok := form[tagValue] + if !ok && !opt.isDefaultExists { + return false, nil + } + + switch value.Kind() { + case reflect.Slice: + if !ok { + vs = []string{opt.defaultValue} + } + + if ok, err = trySetCustom(vs[0], value); ok { + return ok, err + } + + if vs, err = trySplit(vs, field); err != nil { + return false, err + } + + return true, setSlice(vs, value, field) + case reflect.Array: + if !ok { + vs = []string{opt.defaultValue} + } + + if ok, err = trySetCustom(vs[0], value); ok { + return ok, err + } + + if vs, err = trySplit(vs, field); err != nil { + return false, err + } + + if len(vs) != value.Len() { + return false, fmt.Errorf("%q is not valid value for %s", vs, value.Type().String()) + } + + return true, setArray(vs, value, field) + default: + var val string + if !ok { + val = opt.defaultValue + } + + if len(vs) > 0 { + val = vs[0] + } + if ok, err := trySetCustom(val, value); ok { + return ok, err + } + return true, setWithProperType(val, value, field) + } +} + +func setWithProperType(val string, value reflect.Value, field reflect.StructField) error { + switch value.Kind() { + case reflect.Int: + return setIntField(val, 0, value) + case reflect.Int8: + return setIntField(val, 8, value) + case reflect.Int16: + return setIntField(val, 16, value) + case reflect.Int32: + return setIntField(val, 32, value) + case reflect.Int64: + switch value.Interface().(type) { + case time.Duration: + return setTimeDuration(val, value) + } + return setIntField(val, 64, value) + case reflect.Uint: + return setUintField(val, 0, value) + case reflect.Uint8: + return setUintField(val, 8, value) + case reflect.Uint16: + return setUintField(val, 16, value) + case reflect.Uint32: + return setUintField(val, 32, value) + case reflect.Uint64: + return setUintField(val, 64, value) + case reflect.Bool: + return setBoolField(val, value) + case reflect.Float32: + return setFloatField(val, 32, value) + case reflect.Float64: + return setFloatField(val, 64, value) + case reflect.String: + value.SetString(val) + case reflect.Struct: + switch value.Interface().(type) { + case time.Time: + return setTimeField(val, field, value) + case multipart.FileHeader: + return nil + } + return json.Unmarshal(stringToBytes(val), value.Addr().Interface()) + case reflect.Map: + return json.Unmarshal(stringToBytes(val), value.Addr().Interface()) + case reflect.Ptr: + if !value.Elem().IsValid() { + value.Set(reflect.New(value.Type().Elem())) + } + return setWithProperType(val, value.Elem(), field) + default: + return errUnknownType + } + return nil +} + +func setIntField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0" + } + intVal, err := strconv.ParseInt(val, 10, bitSize) + if err == nil { + field.SetInt(intVal) + } + return err +} + +func setUintField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0" + } + uintVal, err := strconv.ParseUint(val, 10, bitSize) + if err == nil { + field.SetUint(uintVal) + } + return err +} + +func setBoolField(val string, field reflect.Value) error { + if val == "" { + val = "false" + } + boolVal, err := strconv.ParseBool(val) + if err == nil { + field.SetBool(boolVal) + } + return err +} + +func setFloatField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0.0" + } + floatVal, err := strconv.ParseFloat(val, bitSize) + if err == nil { + field.SetFloat(floatVal) + } + return err +} + +func setTimeField(val string, structField reflect.StructField, value reflect.Value) error { + timeFormat := structField.Tag.Get("time_format") + if timeFormat == "" { + timeFormat = time.RFC3339 + } + + switch tf := strings.ToLower(timeFormat); tf { + case "unix", "unixnano": + tv, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return err + } + + d := time.Duration(1) + if tf == "unixnano" { + d = time.Second + } + + t := time.Unix(tv/int64(d), tv%int64(d)) + value.Set(reflect.ValueOf(t)) + return nil + } + + if val == "" { + value.Set(reflect.ValueOf(time.Time{})) + return nil + } + + l := time.Local + if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC { + l = time.UTC + } + + if locTag := structField.Tag.Get("time_location"); locTag != "" { + loc, err := time.LoadLocation(locTag) + if err != nil { + return err + } + l = loc + } + + t, err := time.ParseInLocation(timeFormat, val, l) + if err != nil { + return err + } + + value.Set(reflect.ValueOf(t)) + return nil +} + +func setArray(vals []string, value reflect.Value, field reflect.StructField) error { + for i, s := range vals { + err := setWithProperType(s, value.Index(i), field) + if err != nil { + return err + } + } + return nil +} + +func setSlice(vals []string, value reflect.Value, field reflect.StructField) error { + slice := reflect.MakeSlice(value.Type(), len(vals), len(vals)) + err := setArray(vals, slice, field) + if err != nil { + return err + } + value.Set(slice) + return nil +} + +func setTimeDuration(val string, value reflect.Value) error { + d, err := time.ParseDuration(val) + if err != nil { + return err + } + value.Set(reflect.ValueOf(d)) + return nil +} + +func head(str, sep string) (head string, tail string) { + head, tail, _ = strings.Cut(str, sep) + return head, tail +} + +func setFormMap(ptr any, form map[string][]string) error { + el := reflect.TypeOf(ptr).Elem() + + if el.Kind() == reflect.Slice { + ptrMap, ok := ptr.(map[string][]string) + if !ok { + return ErrConvertMapStringSlice + } + for k, v := range form { + ptrMap[k] = v + } + + return nil + } + + ptrMap, ok := ptr.(map[string]string) + if !ok { + return ErrConvertToMapString + } + for k, v := range form { + ptrMap[k] = v[len(v)-1] // pick last + } + + return nil +} diff --git a/internal/pkg/binding/init.go b/internal/pkg/binding/init.go new file mode 100644 index 0000000..bc1c84c --- /dev/null +++ b/internal/pkg/binding/init.go @@ -0,0 +1,61 @@ +package binding + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/go-playground/locales/en" + "github.com/go-playground/locales/zh" + ut "github.com/go-playground/universal-translator" + "github.com/go-playground/validator/v10" + enTranslations "github.com/go-playground/validator/v10/translations/en" + chTranslations "github.com/go-playground/validator/v10/translations/zh" +) + +var trans ut.Translator + +// loca 通常取决于 http 请求头的 'Accept-Language' +func SetValidatorTrans(local string) (err error) { + if v, ok := Validator.Engine().(*validator.Validate); ok { + zhT := zh.New() // chinese + enT := en.New() // english + uni := ut.New(enT, zhT, enT) + + var o bool + trans, o = uni.GetTranslator(local) + if !o { + return fmt.Errorf("uni.GetTranslator(%s) failed", local) + } + // register translate + // 注册翻译器 + switch local { + case "en": + err = enTranslations.RegisterDefaultTranslations(v, trans) + case "zh": + err = chTranslations.RegisterDefaultTranslations(v, trans) + default: + err = enTranslations.RegisterDefaultTranslations(v, trans) + } + return + } + return +} + +func ValidatorErrors(err error) string { + if errors, ok := err.(validator.ValidationErrors); ok { + errs := make(map[string]any) + for _, e := range errors { + errs[e.StructField()] = strings.Replace(e.Translate(trans), e.StructField(), "", -1) + } + + // 将 map 转换为 JSON 格式的字节切片 + jsonData, err := json.MarshalIndent(errs, "", " ") + if err != nil { + return err.Error() + } + return string(jsonData) + } else { + return err.Error() + } +} diff --git a/internal/pkg/binding/json.go b/internal/pkg/binding/json.go new file mode 100644 index 0000000..650c899 --- /dev/null +++ b/internal/pkg/binding/json.go @@ -0,0 +1,51 @@ +package binding + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" +) + +// EnableDecoderUseNumber is used to call the UseNumber method on the JSON +// Decoder instance. UseNumber causes the Decoder to unmarshal a number into an +// any as a Number instead of as a float64. +var EnableDecoderUseNumber = false + +// EnableDecoderDisallowUnknownFields is used to call the DisallowUnknownFields method +// on the JSON Decoder instance. DisallowUnknownFields causes the Decoder to +// return an error when the destination is a struct and the input contains object +// keys which do not match any non-ignored, exported fields in the destination. +var EnableDecoderDisallowUnknownFields = false + +type jsonBinding struct{} + +func (jsonBinding) Name() string { + return "json" +} + +func (jsonBinding) Bind(req *http.Request, obj any) error { + if req == nil || req.Body == nil { + return errors.New("invalid request") + } + return decodeJSON(req.Body, obj) +} + +func (jsonBinding) BindBody(body []byte, obj any) error { + return decodeJSON(bytes.NewReader(body), obj) +} + +func decodeJSON(r io.Reader, obj any) error { + decoder := json.NewDecoder(r) + if EnableDecoderUseNumber { + decoder.UseNumber() + } + if EnableDecoderDisallowUnknownFields { + decoder.DisallowUnknownFields() + } + if err := decoder.Decode(obj); err != nil { + return err + } + return validate(obj) +} diff --git a/internal/pkg/binding/multipart_form_mapping.go b/internal/pkg/binding/multipart_form_mapping.go new file mode 100644 index 0000000..4c05c4e --- /dev/null +++ b/internal/pkg/binding/multipart_form_mapping.go @@ -0,0 +1,70 @@ +package binding + +import ( + "errors" + "mime/multipart" + "net/http" + "reflect" +) + +type multipartRequest http.Request + +var _ setter = (*multipartRequest)(nil) + +var ( + // ErrMultiFileHeader multipart.FileHeader invalid + ErrMultiFileHeader = errors.New("unsupported field type for multipart.FileHeader") + + // ErrMultiFileHeaderLenInvalid array for []*multipart.FileHeader len invalid + ErrMultiFileHeaderLenInvalid = errors.New("unsupported len of array for []*multipart.FileHeader") +) + +// TrySet tries to set a value by the multipart request with the binding a form file +func (r *multipartRequest) TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (bool, error) { + if files := r.MultipartForm.File[key]; len(files) != 0 { + return setByMultipartFormFile(value, field, files) + } + + return setByForm(value, field, r.MultipartForm.Value, key, opt) +} + +func setByMultipartFormFile(value reflect.Value, field reflect.StructField, files []*multipart.FileHeader) (isSet bool, err error) { + switch value.Kind() { + case reflect.Ptr: + switch value.Interface().(type) { + case *multipart.FileHeader: + value.Set(reflect.ValueOf(files[0])) + return true, nil + } + case reflect.Struct: + switch value.Interface().(type) { + case multipart.FileHeader: + value.Set(reflect.ValueOf(*files[0])) + return true, nil + } + case reflect.Slice: + slice := reflect.MakeSlice(value.Type(), len(files), len(files)) + isSet, err = setArrayOfMultipartFormFiles(slice, field, files) + if err != nil || !isSet { + return isSet, err + } + value.Set(slice) + return true, nil + case reflect.Array: + return setArrayOfMultipartFormFiles(value, field, files) + } + return false, ErrMultiFileHeader +} + +func setArrayOfMultipartFormFiles(value reflect.Value, field reflect.StructField, files []*multipart.FileHeader) (isSet bool, err error) { + if value.Len() != len(files) { + return false, ErrMultiFileHeaderLenInvalid + } + for i := range files { + set, err := setByMultipartFormFile(value.Index(i), field, files[i:i+1]) + if err != nil || !set { + return set, err + } + } + return true, nil +} diff --git a/internal/pkg/binding/query.go b/internal/pkg/binding/query.go new file mode 100644 index 0000000..a1f01c0 --- /dev/null +++ b/internal/pkg/binding/query.go @@ -0,0 +1,17 @@ +package binding + +import "net/http" + +type queryBinding struct{} + +func (queryBinding) Name() string { + return "query" +} + +func (queryBinding) Bind(req *http.Request, obj any) error { + values := req.URL.Query() + if err := mapForm(obj, values); err != nil { + return err + } + return validate(obj) +} diff --git a/internal/pkg/convertor/path.go b/internal/pkg/convertor/path.go new file mode 100644 index 0000000..f18de3f --- /dev/null +++ b/internal/pkg/convertor/path.go @@ -0,0 +1,16 @@ +package convertor + +import ( + "sort" + "strings" + + "management/internal/pkg/sliceutil" +) + +func HandleParentPath(parentPath string) string { + parentPath = strings.ReplaceAll(parentPath, ",,", ",") + paths := sliceutil.RemoveDuplicatesWithMap(strings.Split(parentPath, ",")) + sort.Strings(paths) + parentPath = strings.Join(paths, ",") + "," + return parentPath +} diff --git a/internal/pkg/know/pearadmin/pearadmin.go b/internal/pkg/know/pearadmin/pearadmin.go index ca82195..7af91e0 100644 --- a/internal/pkg/know/pearadmin/pearadmin.go +++ b/internal/pkg/know/pearadmin/pearadmin.go @@ -26,7 +26,7 @@ var PearJson = &dto.PearConfig{ Max: "30", Index: dto.Index{ Id: "10", - Href: "", + Href: "/dashboard", Title: "首页", }, }, diff --git a/internal/pkg/middleware/authorize.go b/internal/pkg/middleware/authorize.go index c78e988..67a1ac7 100644 --- a/internal/pkg/middleware/authorize.go +++ b/internal/pkg/middleware/authorize.go @@ -11,6 +11,7 @@ import ( var defaultMenus = map[string]bool{ "/home.html": true, + "/dashboard": true, "/system/menus": true, "/upload/img": true, "/upload/file": true, diff --git a/internal/pkg/sliceutil/sliceutil.go b/internal/pkg/sliceutil/sliceutil.go new file mode 100644 index 0000000..e6f82aa --- /dev/null +++ b/internal/pkg/sliceutil/sliceutil.go @@ -0,0 +1,31 @@ +package sliceutil + +import "sort" + +// 使用 map 去除重复元素 +func RemoveDuplicatesWithMap[T comparable](slice []T) []T { + result := make([]T, 0, len(slice)) + seen := make(map[T]bool) + for _, v := range slice { + if _, ok := seen[v]; !ok { + seen[v] = true + result = append(result, v) + } + } + return result +} + +// 先排序再去重 +func RemoveDuplicatesWithSort(slice []int) []int { + if len(slice) == 0 { + return slice + } + sort.Ints(slice) + result := []int{slice[0]} + for i := 1; i < len(slice); i++ { + if slice[i] != slice[i-1] { + result = append(result, slice[i]) + } + } + return result +} diff --git a/internal/pkg/strutil/internal.go b/internal/pkg/strutil/internal.go index e750550..51a71d4 100644 --- a/internal/pkg/strutil/internal.go +++ b/internal/pkg/strutil/internal.go @@ -27,7 +27,7 @@ func splitWordsToLower(s string) []string { // upperIndex 获得一个int slice,其元素是一个字符串的所有大写字母索引 func upperIndex(s string) []int { var res []int - for i := 0; i < len(s); i++ { + for i := range len(s) { if 64 < s[i] && s[i] < 91 { res = append(res, i) } diff --git a/internal/pkg/tpl/html_method.go b/internal/pkg/tpl/html_method.go index 36accad..8505ada 100644 --- a/internal/pkg/tpl/html_method.go +++ b/internal/pkg/tpl/html_method.go @@ -53,5 +53,9 @@ func (r *render) Methods() map[string]any { return template.HTML(strings.Join(s, ",")) } + res["toString"] = func(b []byte) string { + return string(b) + } + return res } diff --git a/management b/management index de06472..3d01a5a 100755 Binary files a/management and b/management differ diff --git a/web/templates/manage/home/dashboard.tmpl b/web/templates/manage/home/dashboard.tmpl new file mode 100644 index 0000000..1f0a75b --- /dev/null +++ b/web/templates/manage/home/dashboard.tmpl @@ -0,0 +1 @@ +